{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab0095f-35c4-474a-b354-9305e4e756c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# IMPORTS\n",
    "# =============================================================================\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, random_split, Subset\n",
    "from torchvision.datasets import CIFAR100\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet34, ResNet34_Weights\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87321cc0-dcf9-4a0a-bcd9-169a99531cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# CONFIGURATION\n",
    "# =============================================================================\n",
    "\n",
    "class Config:\n",
    "    \"\"\"Configuration class for all hyperparameters.\"\"\"\n",
    "    # Data parameters\n",
    "    data_dir = './data'\n",
    "\n",
    "    # Model parameters\n",
    "    resnet_embedding_dim = 512                 # ResNet34 output dimension\n",
    "    mlp_hidden_dim = 1024                      # Monotonic MLP hidden dimension\n",
    "    mlp_num_layers = 1                         # Monotonic MLP number of hidden layers\n",
    "\n",
    "    # Training parameters\n",
    "    batch_size = 256                           # Mini-batch size\n",
    "    lr_resnet = 2e-4                           # Learning rate for the ResNet encoder\n",
    "    lr_mlp = 2e-4                              # Learning rate for the monotonic MLP head\n",
    "    seed = 42                                  # Random number generator seed\n",
    "    temperature = 0.1\n",
    "    is_ablation = True                        # Set to True to bypass the monotonic MLP head for the ablation study training run.\n",
    "\n",
    "    # Evaluation parameters\n",
    "    knn_k = 5                                  # Number of nearest neighbors in the k nearest neighbors (knn) algorithm\n",
    "    \n",
    "    # System parameters\n",
    "    base_results_dir = './results_CIFAR100_kNN_monitoring'                   # Name of the base directory to save checkpoints and other results\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'                  # Specify whether using GPU or CPU\n",
    "\n",
    "args = Config()                                # Create an instance of the Config class to pass arguments to functions as needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9676548-12d2-4fd1-9fce-71e8859e473e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# SETUP REPRODUCIBILITY AND DIRECTORIES\n",
    "# =============================================================================\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "\n",
    "run_name = (\n",
    "    f\"temp_{args.temperature}_bs_{args.batch_size}_lr_encoder_{args.lr_resnet}_num_mlp_layers_{args.mlp_num_layers}\"\n",
    "    f\"{f'_lr_head_{args.lr_mlp}' if not args.is_ablation else ''}\"        # Conditionally add head learning rate\n",
    "    f\"{'_ablation' if args.is_ablation else ''}\"       \n",
    ")\n",
    "loading_dir = os.path.join(args.base_results_dir, run_name)\n",
    "os.makedirs(loading_dir, exist_ok=True)\n",
    "print(f\"Best model will be loaded from this directory: {loading_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ee61a76-82cc-43ef-8f7b-80ec19b4aa9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# TRANSFORMS\n",
    "# =============================================================================\n",
    "\n",
    "# Use ImageNet's standard normalization statistics for the pre-trained ResNet\n",
    "imagenet_mean = [0.485, 0.456, 0.406]\n",
    "imagenet_std = [0.229, 0.224, 0.225]\n",
    "\n",
    "eval_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),                                                           # Converts image to a Pytorch tensor\n",
    "    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)                       # Normalizes by subtracting mean and dividing by std\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e95e61f4-d62f-44c7-8694-66b65222bb7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# MODEL CLASSES\n",
    "# =============================================================================\n",
    "\n",
    "class MonotonicMLP(nn.Module):\n",
    "    \"\"\"Monotonic MLP with positive weights to ensure an increasing function.\"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim, num_layers):\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        layers = []\n",
    "        current_dim = input_dim\n",
    "        for _ in range(num_layers):\n",
    "            layers.append(nn.Linear(current_dim, hidden_dim))\n",
    "            current_dim = hidden_dim\n",
    "        self.hidden_layers = nn.ModuleList(layers)\n",
    "        self.output_layer = nn.Linear(current_dim, input_dim)\n",
    "\n",
    "    def forward(self, x):                                   # Implement monotonicity constraints via squared weights and non-decreasing activations\n",
    "        z = x   \n",
    "        for layer in self.hidden_layers:\n",
    "            positive_weight = layer.weight**2               # Using squared weights to ensure positivity in the hidden layers during the forward pass\n",
    "            z = F.leaky_relu(F.linear(z, positive_weight, layer.bias))  # Uses the non-decreasing Leaky Relu activation function\n",
    "        positive_weight_out = self.output_layer.weight**2\n",
    "        z = F.linear(z, positive_weight_out, self.output_layer.bias)    # Squared weights are used for positivity in the output as well\n",
    "        return z\n",
    "\n",
    "        \n",
    "class MonoCon_ResNet(nn.Module):\n",
    "    \"\"\"Metric learning model using a pre-trained ResNet and a Monotonic MLP head.\"\"\"\n",
    "    def __init__(self, resnet_embedding_dim, mlp_hidden_dim, mlp_num_layers, is_ablation=False):\n",
    "        super().__init__()\n",
    "        self.is_ablation = is_ablation                                       # Flag indicating whether to run the full model or ablation study\n",
    "        self.encoder = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)      # Initialize ResNet34 encoder with pre-trained weights\n",
    "        self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # Modify convolutional layer for low res CIFAR-100 images\n",
    "        self.encoder.maxpool = nn.Identity()             # Remove max pooling layer to prevent over-compressing low res CIFAR-100 images\n",
    "        self.encoder.fc = nn.Identity()                  # Remove the classification layer at the end\n",
    "        if not self.is_ablation:\n",
    "            self.monotonic_mlp = MonotonicMLP(resnet_embedding_dim, mlp_hidden_dim, mlp_num_layers) # Ablation study is run without the MLP head\n",
    "\n",
    "    def forward(self, x):\n",
    "        y_encoder = self.encoder(x)\n",
    "        y = y_encoder if self.is_ablation else self.monotonic_mlp(y_encoder)    \n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bb09bf3-827f-43f2-92fe-f231c641c7db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# FUNCTIONS FOR COMPUTING EMBEDDINGS AND VALIDATION METRICS\n",
    "# =============================================================================\n",
    "\n",
    "def calculate_recall_at_k(query_embed, query_labels, gallery_embed, gallery_labels, k):\n",
    "    \"\"\"Calculates Recall@K.\"\"\"\n",
    "    sim_matrix = torch.matmul(query_embed, gallery_embed.T)           # Computes the cosine similarity matrix between query and gallery embeddings\n",
    "    \n",
    "    # Extract indices and labels of the top k matches (largest cosine similarity) to the query\n",
    "    top_k_indices = torch.topk(sim_matrix, k=k, dim=1).indices        \n",
    "    top_k_labels = gallery_labels[top_k_indices]\n",
    "\n",
    "    # Compute the fraction of queries for which a match was found within the top k entries\n",
    "    correct_recalls = (top_k_labels == query_labels.unsqueeze(1)).any(dim=1)\n",
    "    recall_at_k = correct_recalls.float().mean().item()\n",
    "    return recall_at_k\n",
    "\n",
    "\n",
    "def compute_embeddings(model, dataloader, device):\n",
    "    model.eval()\n",
    "    y_list, labels_list = [], []\n",
    "    pbar = tqdm(dataloader, desc=\"Computing embeddings\", unit=\"batch\", leave=False)\n",
    "    for images, labels in pbar:\n",
    "        images = images.to(device)\n",
    "        labels_list.append(labels.cpu())\n",
    "        with torch.no_grad():\n",
    "            y_batch = model(images)\n",
    "        y_list.append(y_batch.cpu())\n",
    "    y_embed = torch.cat(y_list, dim=0)\n",
    "    all_labels = torch.cat(labels_list, dim=0)\n",
    "    return y_embed, all_labels\n",
    "\n",
    "\n",
    "def evaluate_embeddings(train_embed, train_labels, test_embed, test_labels, k):\n",
    "    knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)\n",
    "    knn.fit(train_embed, train_labels)\n",
    "    preds = knn.predict(test_embed)\n",
    "    accuracy = accuracy_score(test_labels, preds)\n",
    "    print(f\"k-NN Classification Accuracy: {accuracy * 100:.2f}%\")\n",
    "    return accuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cbb2f7d-045f-406e-b26a-db25f9d6265c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# DATASETS AND DATALOADERS\n",
    "# =============================================================================\n",
    "# Create train and test datasets\n",
    "train_dataset = CIFAR100(root=args.data_dir, train=True, download=True, transform=eval_transform)   # Tensor-ified and normalized train dataset\n",
    "test_dataset = CIFAR100(root=args.data_dir, train=False, download=True, transform=eval_transform)   # Tensor-ified and normalized test dataset\n",
    "\n",
    "num_train = len(train_dataset)\n",
    "indices = list(range(num_train))\n",
    "split = int(0.9 * num_train)\n",
    "print(split)\n",
    "random.seed(args.seed)\n",
    "random.shuffle(indices)\n",
    "subset_indices = indices[split:]\n",
    "\n",
    "train_data_subset = Subset(train_dataset,subset_indices)\n",
    "print(len(train_data_subset))\n",
    "\n",
    "# 4. Create train and test dataloaders\n",
    "train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "train_subset_loader =  DataLoader(train_data_subset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)\n",
    "\n",
    "print(\"Data loaded successfully.\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "178e824f-46a2-4398-b760-2a058bbf9ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# LOAD THE BEST MODEL\n",
    "# =============================================================================\n",
    "model = MonoCon_ResNet(\n",
    "    resnet_embedding_dim=args.resnet_embedding_dim,\n",
    "    mlp_hidden_dim=args.mlp_hidden_dim,\n",
    "    mlp_num_layers=args.mlp_num_layers,\n",
    "    is_ablation=args.is_ablation,\n",
    ").to(args.device)\n",
    "\n",
    "best_model_path = os.path.join(loading_dir, 'best_model.pth')\n",
    "\n",
    "if not os.path.exists(best_model_path):\n",
    "    raise FileNotFoundError(f\"Model file not found at {best_model_path}.\")\n",
    "model.load_state_dict(torch.load(best_model_path, map_location=args.device))\n",
    "print(\"Model loaded successfully.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7043a651-bc4c-4a17-9303-10c89a2c48e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# COMPUTE AND NORMALIZE TRAIN AND TEST EMBEDDINGS\n",
    "# =============================================================================\n",
    "\n",
    "# Compute embeddings\n",
    "train_embed, train_labels = compute_embeddings(model, train_loader, args.device)\n",
    "test_embed, test_labels = compute_embeddings(model, test_loader, args.device)\n",
    "train_subset_embed, train_subset_labels = compute_embeddings(model, train_subset_loader, args.device)\n",
    "\n",
    "######### Normalize embeddings #####################\n",
    "# Uncomment tensor shape to make make sure the correct dimension is used to normalize embeddings\n",
    "# print(train_embed.shape)\n",
    "# print(test_embed.shape)\n",
    "train_embed_norm = F.normalize(train_embed, p=2, dim=1)\n",
    "test_embed_norm = F.normalize(test_embed, p=2, dim=1)\n",
    "train_subset_embed_norm = F.normalize(train_subset_embed, p=2, dim=1)\n",
    "print(train_subset_embed_norm.shape)\n",
    "###################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b31291-a9a9-4ef4-9cbc-0f8c96e6791b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# COMPUTE Recall@k AS A FUNCTION OF k AND SAVE NUMPY ARRAY\n",
    "# =============================================================================\n",
    "k_list = list(range(1,21))\n",
    "recall_at_k_list = np.zeros(len(k_list))\n",
    "filename  = f\"recall_at_k_vs_k{'_ablation' if args.is_ablation else ''}.npy\"\n",
    "for j in k_list:\n",
    "    recall_at_k_list[j-1] = calculate_recall_at_k(test_embed_norm, test_labels,train_embed_norm, train_labels, j)\n",
    "    print(f\"Recall@{j} on test embeddings is {recall_at_k_list[j-1]}\")\n",
    "recall_path = os.path.join(loading_dir,filename);\n",
    "np.save(recall_path,recall_at_k_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2bb7cc6-b917-47fa-9d28-47323acf925d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================================================\n",
    "# UNCOMMENT THIS PART TO LOAD AND PLOT Recall@k VERSUS k FOR BASELINE AND MONOCON\n",
    "# ================================================================================\n",
    "base_string = f\"temp_{args.temperature}_bs_{args.batch_size}_lr_encoder_{args.lr_resnet}_num_mlp_layers_{args.mlp_num_layers}\"\n",
    "monocon_base_dir = (f\"{base_string}\" f\"_lr_head_{args.lr_mlp}\")\n",
    "baseline_base_dir = (f\"{base_string}\" f\"_ablation\")\n",
    "monocon_dir = os.path.join(args.base_results_dir,monocon_base_dir)\n",
    "baseline_dir = os.path.join(args.base_results_dir,baseline_base_dir)\n",
    "monocon_filename  = f\"recall_at_k_vs_k.npy\"\n",
    "baseline_filename = f\"recall_at_k_vs_k_ablation.npy\"\n",
    "monocon_path = os.path.join(monocon_dir,monocon_filename)\n",
    "baseline_path = os.path.join(baseline_dir,baseline_filename)\n",
    "\n",
    "monocon_recall_at_k = np.load(monocon_path)\n",
    "baseline_recall_at_k = np.load(baseline_path)\n",
    "print(len(monocon_recall_at_k))\n",
    "print(len(baseline_recall_at_k))\n",
    "k_list = list(range(1,len(monocon_recall_at_k)+1))\n",
    "print(len(k_list))\n",
    "\n",
    "# ================================================================================\n",
    "# CREATE PLOT\n",
    "# ================================================================================\n",
    "plt.figure()\n",
    "plt.plot(k_list, baseline_recall_at_k, marker='o', linestyle='-', label='Baseline Recall@k')\n",
    "plt.plot(k_list, monocon_recall_at_k, marker='x', linestyle='--', label='MonoCon Recall@k')\n",
    "\n",
    "# Add labels and title for clarity\n",
    "plt.xlabel('k',fontsize=16)\n",
    "plt.ylabel('Recall@k',fontsize=16)\n",
    "#plt.title('Recall@k vs. k for Baseline and MonoCon')\n",
    "\n",
    "# Add a legend to distinguish the lines\n",
    "plt.legend(loc='lower right',fontsize=14)\n",
    "plt.grid(False)\n",
    "plt.savefig('recall_at_k_vs_k_plot_CIFAR-100.png')\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb90a413-7c5e-4677-b7cd-3b3c044924b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# COMPUTE Recall@1 and Recall@5\n",
    "# =============================================================================\n",
    "final_recall_at_1 = calculate_recall_at_k(test_embed_norm, test_labels,train_embed_norm, train_labels, 1)\n",
    "print(f\"Final Recall@k value for k = {1} on test embeddings is {final_recall_at_1}\")\n",
    "\n",
    "final_recall_at_5 = calculate_recall_at_k(test_embed_norm, test_labels,train_embed_norm, train_labels, 5)\n",
    "print(f\"Final Recall@k value for k = {5} on test embeddings is {final_recall_at_5}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10871fd4-4a43-4553-bd23-5cb7e14c8f9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# COMPUTE k-NN classification accuracy using normalized \n",
    "# =============================================================================\n",
    "final_kNN_accuracy = evaluate_embeddings(train_embed_norm, train_labels, test_embed_norm, test_labels, args.knn_k)\n",
    "print(f\"Final k-NN classification accuracy for k = {args.knn_k} on test embeddings is {final_kNN_accuracy}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "329e55e5-0f9e-48d8-96af-0d583d88ccf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# COMPUTE EFFECTIVE DIMENSIONALITY USING PCA ON TRAIN EMBEDDINGS\n",
    "# =============================================================================\n",
    "# The effective dimensionality is defined as the number of PCA components required to explain 99% variance of the train dataset\n",
    "\n",
    "# --- Fit a full PCA on the training data to determine variance ratios ---\n",
    "print(\"\\n--- Fitting PCA on training data to analyze variance ---\")\n",
    "pca_full = PCA(n_components=None, random_state=42)\n",
    "pca_full.fit(train_embed_norm)\n",
    "\n",
    "cumulative_variance = np.cumsum(pca_full.explained_variance_ratio_)\n",
    "n_for_99_variance = np.searchsorted(cumulative_variance, 0.99) + 1\n",
    "print(f\"Number of components to explain 99% of TRAIN variance: {n_for_99_variance}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9538ac0e-009d-4dff-aa02-54be4f3ce103",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# USE THIS BLOCK ONLY FOR COMPUTING METRICS FOR COMPRESSED REPRESENTATIONS\n",
    "# =============================================================================\n",
    "\n",
    "# =============================================================================\n",
    "# TRUNCATE EMBEDDINGS AND EVALUATE ACROSS DIMENSIONS\n",
    "# =============================================================================\n",
    "\n",
    "# Convert initial tensors to numpy for PCA and k-NN\n",
    "train_embed_np = train_embed_norm.numpy()\n",
    "train_labels_np = train_labels.numpy()\n",
    "test_embed_np = test_embed_norm.numpy()\n",
    "\n",
    "# --- Loop through, create a specific PCA for each dimension, and evaluate ---\n",
    "truncation_dims = [128, 64, 16]\n",
    "k_for_knn = 5\n",
    "\n",
    "for dim in truncation_dims:\n",
    "    print(f\"\\n-- Evaluating for top {dim} PCA components --\")\n",
    "    \n",
    "    # 1. Define and fit a new PCA model for the current dimension\n",
    "    pca_reconstruction = PCA(n_components=dim, random_state=42)\n",
    "    pca_reconstruction.fit(train_embed_np) # Fit on NumPy training data\n",
    "    \n",
    "    # 2. Calculate Reconstruction Error \n",
    "    # Reconstruct the test data by transforming and inverse_transforming\n",
    "    reconstructed_test_np = pca_reconstruction.inverse_transform(pca_reconstruction.transform(test_embed_np))\n",
    "    \n",
    "    # Convert back to PyTorch Tensors for RMSE calculation\n",
    "    original_test_torch = torch.from_numpy(test_embed_np)\n",
    "    reconstructed_test_torch = torch.from_numpy(reconstructed_test_np)\n",
    "    \n",
    "    # Implement the RMSE formula\n",
    "    reconstruction_rmse = torch.sqrt(torch.mean(torch.square(original_test_torch - reconstructed_test_torch)))\n",
    "    print(f\"PCA Reconstruction RMSE: {reconstruction_rmse.item():.8f}\")\n",
    "\n",
    "    # 3. Get the truncated embeddings for downstream tasks\n",
    "    train_embed_truncated = pca_reconstruction.transform(train_embed_np)\n",
    "    test_embed_truncated = pca_reconstruction.transform(test_embed_np)\n",
    "    \n",
    "    # 4. Evaluate k-NN Classification Accuracy (uses NumPy arrays)\n",
    "    evaluate_embeddings(train_embed_truncated, train_labels_np, test_embed_truncated, test_labels.numpy(), k=k_for_knn)\n",
    "    \n",
    "    # 5. Evaluate Recall@k (uses PyTorch Tensors)\n",
    "    # Convert the truncated embeddings back to Tensors\n",
    "    train_embed_truncated_torch = torch.from_numpy(train_embed_truncated)\n",
    "    test_embed_truncated_torch = torch.from_numpy(test_embed_truncated)\n",
    "    \n",
    "    # Use TEST set as QUERY and TRAIN set as GALLERY\n",
    "    recall_at_1 = calculate_recall_at_k(test_embed_truncated_torch, test_labels, train_embed_truncated_torch, train_labels, k=1)\n",
    "    recall_at_5 = calculate_recall_at_k(test_embed_truncated_torch, test_labels, train_embed_truncated_torch, train_labels, k=5)\n",
    "    \n",
    "    print(f\"Recall@1: {recall_at_1 * 100:.2f}%\")\n",
    "    print(f\"Recall@5: {recall_at_5 * 100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4efc23ee-c7e7-4877-aad4-fc80038de654",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===============================================================================\n",
    "# QUANTIFY ROBUSTNESS USING PCA-BASED RMS RECONSTRUCTION ERROR ON TEST EMBEDDINGS\n",
    "# ===============================================================================\n",
    "\n",
    "# 1. Define a new PCA model with number of components equal to effective dimensionality defined in the previous cell\n",
    "pca_reconstruction = PCA(n_components=n_for_99_variance, random_state=42)\n",
    "\n",
    "# 2. Fit it on the training data\n",
    "pca_reconstruction.fit(train_embed_norm)\n",
    "\n",
    "# 3. Reconstruct the test data by transforming and inverse_transforming\n",
    "test_reconstructed = pca_reconstruction.inverse_transform(pca_reconstruction.transform(test_embed_norm))\n",
    "\n",
    "# 4. Calculate the Root Mean Squared Error between original and reconstructed test data\n",
    "reconstruction_rmse = torch.sqrt(torch.mean(torch.square(test_embed_norm - test_reconstructed)))\n",
    "print(f\"PCA-based Root Mean Squared Reconstruction Error on test embeddings is: {reconstruction_rmse:.8f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a25b08c-3047-4ff6-adce-79ac758e032e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# COMPUTE FEATURE CORRELATION MATRIX, AND VISUALIZE AND STORE ITS CLUSTERMAP\n",
    "# =============================================================================\n",
    "\n",
    "print(f\"\"\" Generating feature covariance matrix computed using normalized output features for a fixed random subset of training data\"\"\")\n",
    "\n",
    "feat_corr_matrix = np.corrcoef(train_subset_embed_norm.to('cpu').numpy(), rowvar=False)\n",
    "\n",
    "# Generate Correlation Clustermap\n",
    "\n",
    "print(f\"\"\" Generating clustermap of feature covariance matrix\"\"\")\n",
    "\n",
    "path_corr = os.path.join(loading_dir, f\"best_model_norm_output_feature_covariance_matrix.png\")\n",
    "\n",
    "g = sns.clustermap(\n",
    "    feat_corr_matrix,\n",
    "    cmap='RdBu_r',  # Red-Blue diverging colormap\n",
    "    vmin=-1,\n",
    "    vmax=1,\n",
    "    figsize=(10, 10),\n",
    "    cbar_pos=(0.02, 0.8, 0.03, 0.15) # Position colorbar\n",
    ")\n",
    "\n",
    "g.ax_heatmap.set_xlabel('')\n",
    "g.ax_heatmap.set_ylabel('')\n",
    "g.fig.suptitle('MonoCon', fontsize=50, y=1.06)\n",
    "g.savefig(path_corr)\n",
    "\n",
    "print(f\"Plot saved to {path_corr}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2541ff4b-f66b-4704-979d-866c31d1d176",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===================================================================================\n",
    "# COMPUTE ROOT MEAN SQUARE OFF-DIAGONAL CORRELATION IN THE FEATURE CORRELATION MATRIX\n",
    "# ===================================================================================\n",
    "embed_dim = feat_corr_matrix.shape[0]\n",
    "mask = 1.0 - np.eye(embed_dim)\n",
    "rms_off_diagonal_feat_corr = np.sqrt(np.sum(np.square(feat_corr_matrix*mask))/(embed_dim*(embed_dim-1)))\n",
    "print(f\"RMS Off diagonal correlation in the feature correlation matrix is: {rms_off_diagonal_feat_corr:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "108e99eb-1675-4064-ac31-479dd670a6d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===================================================================================\n",
    "# SAVE PERFORMANCE METRICS TO A JSON FILE\n",
    "# ===================================================================================\n",
    "\n",
    "import json\n",
    "\n",
    "performance_metrics = {\n",
    "    \"kNN accuracy for k=5 (%)\": float(final_kNN_accuracy*100),\n",
    "    \"Recall@1 (%)\": float(final_recall_at_1*100),\n",
    "    \"Recall@5 (%)\": float(final_recall_at_5*100),\n",
    "    \"number of PCA components for 99% variance\": int(n_for_99_variance),\n",
    "    \"PCA-based RMS reconstruction error\": float(reconstruction_rmse),\n",
    "    \"RMS off-diagonal feature correlation\": float(rms_off_diagonal_feat_corr)\n",
    "}\n",
    "\n",
    "json_file_path = os.path.join(loading_dir, 'best_model_performance_metrics.json')\n",
    "with open(json_file_path, \"w\") as f:\n",
    "    json.dump(performance_metrics, f, indent=4)\n",
    "\n",
    "print(f\"Performance metrics saved to {json_file_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
