{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "edfe1206-66cd-4976-952c-b3b105f4563e",
   "metadata": {},
   "source": [
    "# Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e45906-72db-4b2c-953d-ed169b94c909",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78183f72-5e72-4b3a-af21-13a98237952e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from bisect import bisect_right, bisect_left\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import rsatoolbox\n",
    "from rsatoolbox.data import Dataset\n",
    "from rsatoolbox.rdm import calc_rdm\n",
    "from scipy.cluster.hierarchy import linkage, dendrogram\n",
    "from scipy.spatial.distance import cdist\n",
    "from scipy.spatial.distance import cosine\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from scipy.spatial.distance import pdist\n",
    "from scipy.stats import spearmanr, t\n",
    "from skimage import filters, color\n",
    "from sklearn.preprocessing import normalize\n",
    "from tqdm import tqdm\n",
    "from cliffs_delta import cliffs_delta"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ecf3f34-e4ef-49f8-a401-5e90b74a0a1e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9501201d-9595-4882-bd2b-2f26088ee404",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model identifiers as they appear in file names, and the matching display names for plots.\n",
    "# The two lists are parallel: model_tags[i] labels models[i].\n",
    "models = ['uni2', 'virchow2', 'prov', 'conch', 'plip', 'keep', 'dinov2']\n",
    "model_tags = ['UNI2', 'Virchow2', 'Prov-Gigapath', 'CONCH', 'PLIP', 'KEEP', 'ViT-Dinov2']\n",
    "cancer_types = ['BRCA', 'COAD', 'LUAD', 'LUSC']\n",
    "assert len(models) == len(model_tags), \"models and model_tags must have the same length\"\n",
    "\n",
    "# Layout of each saved embedding file: total_slides slides x total_patches patches per slide.\n",
    "n_batches = 5\n",
    "total_slides = 250\n",
    "total_patches = 250\n",
    "\n",
    "# Integer division would silently drop a remainder, so require an even split.\n",
    "assert total_slides % n_batches == 0, \"total_slides must be divisible by n_batches\"\n",
    "assert total_patches % n_batches == 0, \"total_patches must be divisible by n_batches\"\n",
    "\n",
    "num_slides_per_batch = total_slides // n_batches\n",
    "num_patches_per_batch = total_patches // n_batches\n"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Root directory for all inputs and outputs. Set the RSA_PROJECT_SAVE_DIR environment\n",
    "# variable to run the notebook elsewhere; the default is the original location.\n",
    "PROJECT_SAVE_DIR = os.environ.get('RSA_PROJECT_SAVE_DIR', '/lotterlab/users/vmishra/RSA_08282025/')\n",
    "\n",
    "# Full per-model embedding files (input) and the per-batch splits written by this notebook.\n",
    "orig_embeddings_path = f\"{PROJECT_SAVE_DIR}/embeddings/\"\n",
    "batched_embeddings_path = f\"{PROJECT_SAVE_DIR}/embeddings-batched/\"\n",
    "\n"
   ],
   "id": "dc635418b8faca29"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Switch between the plain and the '-normalized' set of files; the tag is part of\n",
    "# every embedding/RDM file name and of the output folder names.\n",
    "normalized = False\n",
    "if normalized:\n",
    "    norm_tag = '-normalized'\n",
    "else:\n",
    "    norm_tag = ''\n",
    "\n",
    "rdm_path = f\"{PROJECT_SAVE_DIR}/rdms{norm_tag}/\"\n",
    "plot_path = f\"{PROJECT_SAVE_DIR}/plots{norm_tag}/\""
   ],
   "id": "648d84955618873e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Make sure every output folder exists before anything is written to it.\n",
    "for out_dir in (batched_embeddings_path, rdm_path, plot_path):\n",
    "    os.makedirs(out_dir, exist_ok=True)"
   ],
   "id": "6f1f409a9c1ddcc7"
  },
  {
   "cell_type": "markdown",
   "id": "8acdaf5a-08b4-4fa0-b7b4-5e3a6987a56f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Splitting Into 5 Batches of 50 slides/50 patches each for validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96690212-a29e-41d7-852c-82bf69410f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split every (cancer type, model) embedding file into n_batches files.\n",
    "# Batch k holds slides [k * num_slides_per_batch, (k + 1) * num_slides_per_batch) and,\n",
    "# for each of those slides, only its first num_patches_per_batch patches.\n",
    "for cancer_type in cancer_types:\n",
    "    for model in models:\n",
    "        file_path = os.path.join(orig_embeddings_path, f\"embeddings_{cancer_type}{norm_tag}-{model}.npy\")\n",
    "        print(f\"Processing {file_path}...\")\n",
    "\n",
    "        embeddings = np.load(file_path)\n",
    "        embedding_dim = embeddings.shape[1]\n",
    "        num_total = total_slides * total_patches\n",
    "        assert embeddings.shape[0] == num_total, f\"Unexpected shape for {file_path}\"\n",
    "\n",
    "        # View the flat (slide * patch, feature) array as (slide, patch, feature).\n",
    "        embeddings = embeddings.reshape(total_slides, total_patches, embedding_dim)\n",
    "\n",
    "        for batch_idx in range(n_batches):\n",
    "            start_slide = batch_idx * num_slides_per_batch\n",
    "            end_slide = start_slide + num_slides_per_batch\n",
    "\n",
    "            # Flatten back to one row per patch, slide after slide.\n",
    "            batch = embeddings[start_slide:end_slide, :num_patches_per_batch, :].reshape(-1, embedding_dim)\n",
    "\n",
    "            batch_path = os.path.join(\n",
    "                batched_embeddings_path,\n",
    "                f\"embeddings_{cancer_type}{norm_tag}-{model}-batch{batch_idx}.npy\",\n",
    "            )\n",
    "            np.save(batch_path, batch)\n",
    "            print(f\"Saved batch {batch_idx} to {batch_path}\")"
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# RDMs",
   "id": "cbd1f79f1ab6bccc"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39a8674-c7ea-4f25-8761-9139d79271ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One RDM per (model, batch): the batch files of all cancer types are stacked in\n",
    "# cancer_types order, so the RDM rows/columns are grouped by disease.\n",
    "for model in models:\n",
    "    for batch_idx in range(n_batches):\n",
    "        print(f\"Processing RDM for model: {model}, batch: {batch_idx}\")\n",
    "\n",
    "        embeddings_list = [\n",
    "            np.load(os.path.join(batched_embeddings_path, f\"embeddings_{cancer_type}{norm_tag}-{model}-batch{batch_idx}.npy\"))\n",
    "            for cancer_type in cancer_types\n",
    "        ]\n",
    "        # One disease label per patch row, in the same order as the stacked rows.\n",
    "        labels = [\n",
    "            cancer_type\n",
    "            for cancer_type, emb in zip(cancer_types, embeddings_list)\n",
    "            for _ in range(len(emb))\n",
    "        ]\n",
    "\n",
    "        embeddings = np.concatenate(embeddings_list, axis=0)\n",
    "\n",
    "        dataset = Dataset(measurements=embeddings, obs_descriptors={'disease': labels})\n",
    "        rdm_matrix = calc_rdm(dataset, method='euclidean').get_matrices()[0]\n",
    "\n",
    "        save_path = os.path.join(rdm_path, f\"rdm_matrix_{model}{norm_tag}_batch{batch_idx}.npy\")\n",
    "        np.save(save_path, rdm_matrix)\n",
    "        print(f\"Saved RDM to {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7647284a-7e88-4539-88e9-c62ea68450fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_plot(rdm, model_name):\n",
    "    \"\"\"Draw an RDM as a heatmap with one labelled block per cancer type and save it as PNG.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    rdm : 2-D array\n",
    "        Square dissimilarity matrix whose rows/columns are grouped into\n",
    "        len(cancer_types) equally sized consecutive blocks.\n",
    "    model_name : str\n",
    "        Used as the figure title and as part of the output file name.\n",
    "    \"\"\"\n",
    "    n_groups = len(cancer_types)\n",
    "    n = rdm.shape[0] / n_groups  # rows per cancer type\n",
    "    divider_positions = [k * n for k in range(1, n_groups)]\n",
    "    tick_positions = [k * n + n // 2 for k in range(n_groups)]  # centre of each block\n",
    "\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    ax = sns.heatmap(rdm, cmap='Blues', annot=False, cbar=True)\n",
    "    ax.collections[0].colorbar.set_label(\"Normalized Distance\", fontsize=12)\n",
    "    plt.title(model_name, fontweight='bold', fontsize=20)\n",
    "    plt.xticks(tick_positions, cancer_types, rotation=0, ha='center', fontsize=12, fontweight='bold')\n",
    "    plt.yticks(tick_positions, cancer_types, rotation=0, fontsize=12, fontweight='bold')\n",
    "\n",
    "    # Add divider lines between disease types\n",
    "    for pos in divider_positions:\n",
    "        ax.axhline(pos, color='black', linewidth=1.5)\n",
    "        ax.axvline(pos, color='black', linewidth=1.5)\n",
    "\n",
    "    plt.savefig(plot_path + \"rdm_matrix\" + model_name + norm_tag + \".png\", dpi=300, bbox_inches='tight')\n",
    "\n",
    "# The loop variables are deliberately not called `model` / `t`: the file name must use\n",
    "# this loop's own model id, and `t` would shadow scipy.stats.t from the import cell.\n",
    "for model_id, model_tag in zip(models, model_tags):\n",
    "    print(model_id)\n",
    "    rdm_mat = np.load(os.path.join(rdm_path, f\"rdm_matrix_{model_id}{norm_tag}_batch0.npy\")) # just plot first batch\n",
    "    rdm_mat = rdm_mat / rdm_mat.max()  # scale to [0, 1] to match the colorbar label\n",
    "    make_plot(rdm_mat, model_tag)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9562685-6a78-4585-b71b-73e2c8aa685d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Spearman and Cosine Similarity Heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f22fd174-12cd-4e3c-999c-3443404d4374",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_models = len(models)\n",
    "\n",
    "cosine_mean = np.zeros((n_models, n_models))\n",
    "cosine_lower = np.zeros((n_models, n_models))\n",
    "cosine_upper = np.zeros((n_models, n_models))\n",
    "\n",
    "spearman_mean = np.zeros((n_models, n_models))\n",
    "spearman_lower = np.zeros((n_models, n_models))\n",
    "spearman_upper = np.zeros((n_models, n_models))\n",
    "\n",
    "def get_range(values):\n",
    "    \"\"\"Return (mean, min, max) of a sequence of per-batch values.\"\"\"\n",
    "    return np.mean(values), np.min(values), np.max(values)\n",
    "\n",
    "def load_rdm_upper_triangle(model_name, batch):\n",
    "    \"\"\"Load one saved RDM and return its strict upper triangle as a 1-D vector.\"\"\"\n",
    "    rdm = np.load(os.path.join(rdm_path, f\"rdm_matrix_{model_name}{norm_tag}_batch{batch}.npy\"))\n",
    "    return rdm[np.triu_indices_from(rdm, k=1)]\n",
    "\n",
    "# Per-batch similarity of every model pair, shape (n_models, n_models, n_batches).\n",
    "cosine_vals = np.zeros((n_models, n_models, n_batches))\n",
    "spearman_vals = np.zeros((n_models, n_models, n_batches))\n",
    "\n",
    "# Both measures are symmetric in their two arguments, so only pairs with j >= i are\n",
    "# computed and then mirrored. RDM i is loaded once per batch rather than once per pair.\n",
    "for batch in range(n_batches):\n",
    "    for i, model_i in enumerate(models):\n",
    "        tri_i = load_rdm_upper_triangle(model_i, batch)\n",
    "        for j in range(i, n_models):\n",
    "            tri_j = tri_i if j == i else load_rdm_upper_triangle(models[j], batch)\n",
    "\n",
    "            cos_sim = 1 - cosine(tri_i, tri_j)\n",
    "            spearman_corr, _ = spearmanr(tri_i, tri_j)\n",
    "\n",
    "            cosine_vals[i, j, batch] = cosine_vals[j, i, batch] = cos_sim\n",
    "            spearman_vals[i, j, batch] = spearman_vals[j, i, batch] = spearman_corr\n",
    "\n",
    "# Summarise across batches: mean plus the min/max range.\n",
    "for i in range(n_models):\n",
    "    for j in range(n_models):\n",
    "        cosine_mean[i, j], cosine_lower[i, j], cosine_upper[i, j] = get_range(cosine_vals[i, j])\n",
    "        spearman_mean[i, j], spearman_lower[i, j], spearman_upper[i, j] = get_range(spearman_vals[i, j])\n",
    "\n",
    "def format_annot(mean_mat, lower_mat, upper_mat):\n",
    "    \"\"\"Build the per-cell heatmap text: the mean on one line, '[min–max]' on the next.\"\"\"\n",
    "    n = mean_mat.shape[0]\n",
    "    annot = np.empty((n, n), dtype=object)\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            m = mean_mat[i, j]\n",
    "            l = lower_mat[i, j]\n",
    "            u = upper_mat[i, j]\n",
    "            annot[i, j] = f\"{m:.3f}\\n[{l:.3f}–{u:.3f}]\"\n",
    "    return annot\n",
    "\n",
    "cosine_annot = format_annot(cosine_mean, cosine_lower, cosine_upper)\n",
    "spearman_annot = format_annot(spearman_mean, spearman_lower, spearman_upper)\n",
    "\n",
    "# Plot cosine similarity\n",
    "plt.figure(figsize=(10, 8))\n",
    "sns.heatmap(cosine_mean, annot=cosine_annot, fmt=\"\", xticklabels=model_tags, yticklabels=model_tags,\n",
    "            cmap=\"Reds\", vmin=0.85, vmax=1, cbar_kws={\"label\": \"Cosine Similarity\"})\n",
    "plt.title(f\"Cosine Similarity Between Model RDMs\\nMean [Range Across {n_batches} Batches]\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(plot_path + f\"cosineHeatmap{norm_tag}.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "# Plot Spearman correlation\n",
    "plt.figure(figsize=(10, 8))\n",
    "sns.heatmap(spearman_mean, annot=spearman_annot, fmt=\"\", xticklabels=model_tags, yticklabels=model_tags,\n",
    "            cmap=\"Reds\", vmin=0, vmax=1, cbar_kws={\"label\": \"Spearman Correlation\"})\n",
    "plt.title(f\"Spearman Correlation Between Model RDMs\\nMean [Range Across {n_batches} Batches]\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(plot_path + f\"spearmanHeatmap{norm_tag}.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74e16b76-c41b-4dff-bd8e-fe73fcca8b97",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Mean Spearman Correlation and Cosine Similarity per Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3caa6796-d9a7-4374-9eca-1f757a8ee782",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_models = len(models)\n",
    "\n",
    "# For each model: mean similarity to the other n_models - 1 models\n",
    "# (row sum with the diagonal self-similarity removed).\n",
    "off_diagonal_sum = np.sum(spearman_mean, axis=1) - np.diag(spearman_mean)\n",
    "average_spearman_corr = off_diagonal_sum / (n_models - 1)\n",
    "\n",
    "df_similarity = pd.DataFrame({\"Model\": model_tags, \"Average Spearman Correlation\": average_spearman_corr})\n",
    "\n",
    "df_similarity"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "n_models = len(models)  # defined here too so the cell does not depend on the Spearman cell\n",
    "\n",
    "# For each model: mean cosine similarity to the other n_models - 1 models\n",
    "# (row sum with the diagonal self-similarity removed).\n",
    "average_cos = (np.sum(cosine_mean, axis=1) - np.diag(cosine_mean)) / (n_models - 1)\n",
    "\n",
    "# Labelled as a similarity, matching the heatmap colorbar (it is not a correlation).\n",
    "df_similarity = pd.DataFrame({\n",
    "    \"Model\": model_tags,\n",
    "    \"Average Cosine Similarity\": average_cos\n",
    "})\n",
    "\n",
    "df_similarity"
   ],
   "id": "ec2d9696d5e873e7"
  },
  {
   "cell_type": "markdown",
   "id": "d7169f1f-04f5-4160-8052-6fcf2486502e",
   "metadata": {},
   "source": [
    "# Dendrograms (the similarity values below are hard-coded; update them after re-running the heatmap cells above)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e398795-997e-4f95-9ce1-646b7e1bc1f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded mean Spearman values; this replaces any spearman_mean computed above.\n",
    "# Listed row by row for the strict upper triangle, in `models` order.\n",
    "n_models = len(models)\n",
    "spearman_upper_triangle = [\n",
    "    0.279, 0.41, 0.4, 0.223, 0.195, 0.122,\n",
    "    0.466, 0.391, 0.43, 0.371, 0.136,\n",
    "    0.505, 0.54, 0.515, 0.181,\n",
    "    0.485, 0.384, 0.151,\n",
    "    0.734, 0.222,\n",
    "    0.192,\n",
    "]\n",
    "\n",
    "spearman_mean = np.ones((n_models, n_models))  # diagonal stays at 1\n",
    "upper_rows, upper_cols = np.triu_indices(n_models, k=1)\n",
    "spearman_mean[upper_rows, upper_cols] = spearman_upper_triangle\n",
    "# Mirror into the lower triangle so the matrix is symmetric.\n",
    "spearman_mean[upper_cols, upper_rows] = spearman_upper_triangle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18e92d88-7e38-4fd6-9ce5-b6b07d6ab5c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the filled-in matrix to check it by eye before clustering.\n",
    "spearman_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b54f7b54-d08e-41e3-8d34-74ae1f2b341e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn similarity into dissimilarity: 1 - correlation, so the diagonal becomes 0.\n",
    "spearman_distances = 1 - spearman_mean\n",
    "\n",
    "# Convert to condensed form\n",
    "spearman_condensed = squareform(spearman_distances, checks=False)\n",
    "\n",
    "# Hierarchical Clustering (Ward's method)\n",
    "# NOTE(review): SciPy documents Ward linkage as correctly defined only for Euclidean\n",
    "# distances; 1 - Spearman is not one, so the merge heights are heuristic - confirm this is intended.\n",
    "linkage_spearman = linkage(spearman_condensed, method=\"ward\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70d45558-1039-4894-8eff-166fa887bd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upper bounds of the merge-height bands, and one colour per band.\n",
    "thresholds = [0.1, 0.33, 0.55, 0.62, 0.70, 0.81, 1.00]\n",
    "colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'gray']\n",
    "\n",
    "def link_color_func(link_id):\n",
    "    \"\"\"Return the colour for one dendrogram link, chosen by its merge height.\n",
    "\n",
    "    `link_id` is a cluster id as used by scipy's dendrogram: for n leaves the\n",
    "    merged clusters are numbered n .. 2n-2 and cluster n + r is described by row r\n",
    "    of `linkage_spearman`. The height is compared against `thresholds` in order and\n",
    "    the colour of the first band it falls below is returned; heights at or above\n",
    "    the last threshold get the last colour.\n",
    "    \"\"\"\n",
    "    n_leaves = linkage_spearman.shape[0] + 1\n",
    "    row = link_id - n_leaves\n",
    "    dist = linkage_spearman[row, 2]  # height of the node\n",
    "    for i, t in enumerate(thresholds):\n",
    "        if dist < t:\n",
    "            return colors[i]\n",
    "    return colors[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f68e884-b87e-42d9-953b-f7b284694aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Links below color_threshold are coloured per cluster; links above it get the default colour.\n",
    "dendrogram(linkage_spearman, labels=models, leaf_rotation=90, leaf_font_size=12, color_threshold=0.82, ax=ax)\n",
    "ax.set_title(\"Hierarchical Clustering of Model RDM Similarity\", fontweight='bold')\n",
    "ax.set_ylabel('Ward Distance', fontsize=12)\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Lay out only after every label has its final text and size, so nothing is clipped on save.\n",
    "fig.tight_layout()\n",
    "fig.savefig(plot_path + f'ClusteringSpearman{norm_tag}.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d45b38-630c-413e-9b18-f3a3fea7f7ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# just copied from plot since taking so long\n",
    "# Values for the strict upper triangle, row by row in `models` order;\n",
    "# this replaces any cosine_mean computed above.\n",
    "n_models = len(models)\n",
    "cosine_upper_triangle = [\n",
    "    0.939, 0.966, 0.946, 0.924, 0.922, 0.883,\n",
    "    0.964, 0.942, 0.936, 0.933, 0.879,\n",
    "    0.963, 0.957, 0.958, 0.905,\n",
    "    0.947, 0.937, 0.88,\n",
    "    0.972, 0.882,\n",
    "    0.878,\n",
    "]\n",
    "\n",
    "cosine_mean = np.ones((n_models, n_models))  # diagonal stays at 1\n",
    "upper_rows, upper_cols = np.triu_indices(n_models, k=1)\n",
    "cosine_mean[upper_rows, upper_cols] = cosine_upper_triangle\n",
    "# Mirror into the lower triangle so the matrix is symmetric.\n",
    "cosine_mean[upper_cols, upper_rows] = cosine_upper_triangle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1be24901-66f8-4108-a223-d9c896ce79b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn similarity into dissimilarity: 1 - cosine similarity, so the diagonal becomes 0.\n",
    "cosine_distances = 1 - cosine_mean\n",
    "\n",
    "# Convert to condensed form\n",
    "cosine_condensed = squareform(cosine_distances, checks=False)\n",
    "\n",
    "# Hierarchical Clustering (Ward's method)\n",
    "# NOTE(review): SciPy documents Ward linkage as correctly defined only for Euclidean\n",
    "# distances; 1 - cosine similarity is not one, so the merge heights are heuristic - confirm this is intended.\n",
    "linkage_cosine = linkage(cosine_condensed, method=\"ward\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03a6b6d7-dfad-4b6c-8521-92a898439ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Links below color_threshold are coloured per cluster; links above it get the default colour.\n",
    "dendrogram(linkage_cosine, labels=models, leaf_rotation=90, leaf_font_size=12, color_threshold=0.14, ax=ax)\n",
    "ax.set_title(\"Hierarchical Clustering of Model RDM Similarity\", fontweight='bold')\n",
    "ax.set_ylabel('Ward Distance', fontsize=12)\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Lay out only after every label has its final text and size, so nothing is clipped on save.\n",
    "fig.tight_layout()\n",
    "fig.savefig(plot_path + f'ClusteringCosine{norm_tag}.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f176bdce-d5e0-45e4-86b5-06e87911ff44",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Slide Specificity"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def efficient_cliffs_delta(a, b):\n",
    "    \"\"\"Cliff's delta of sample `a` versus sample `b`.\n",
    "\n",
    "    delta = (#{(x, y): x > y} - #{(x, y): x < y}) / (len(a) * len(b)) for x in a, y in b.\n",
    "    Tied pairs count for neither side. Only `b` is sorted; each x is then located in it\n",
    "    by binary search, so no len(a) x len(b) comparison table is built.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    a, b : 1-D sequences or arrays of numbers.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float in [-1, 1]; positive when values in `a` tend to be larger than values in `b`.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ZeroDivisionError if either sample is empty.\n",
    "    \"\"\"\n",
    "    a = np.asarray(a)\n",
    "    b = np.sort(b)\n",
    "    m, n = len(a), len(b)\n",
    "    # side='left' gives the number of b values strictly below x;\n",
    "    # side='right' gives the number of b values <= x, so n minus it is the number strictly above.\n",
    "    n_b_below = int(np.searchsorted(b, a, side='left').sum(dtype=np.int64))\n",
    "    n_b_above = m * n - int(np.searchsorted(b, a, side='right').sum(dtype=np.int64))\n",
    "    # Plain Python ints here, so an empty sample raises instead of returning nan.\n",
    "    delta = (n_b_below - n_b_above) / (m * n)\n",
    "    return delta"
   ],
   "id": "c5d479e3b8f76fcc"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11dfa239-1a34-461b-82cc-75770f83720d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_distances_per_slide(embeddings, num_slides, patches_per_slide):\n",
    "    \"\"\"Collect Euclidean distances between patches of the same slide and of different slides.\n",
    "\n",
    "    Rows of `embeddings` are taken slide by slide: slide s owns rows\n",
    "    [s * patches_per_slide, (s + 1) * patches_per_slide).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (intra, inter) : two 1-D float arrays\n",
    "        intra: every unordered pair of patches within one slide, once.\n",
    "        inter: for each slide, the distance from each of its patches to every row\n",
    "        outside that slide. A cross-slide pair therefore appears twice, once from\n",
    "        each side.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    IndexError if `embeddings` has fewer than num_slides * patches_per_slide rows.\n",
    "    \"\"\"\n",
    "    n_rows = embeddings.shape[0]\n",
    "    if num_slides * patches_per_slide > n_rows:\n",
    "        raise IndexError(\n",
    "            f\"need {num_slides * patches_per_slide} rows for {num_slides} slides, got {n_rows}\"\n",
    "        )\n",
    "\n",
    "    # One array per slide, concatenated once at the end. Extending Python lists with\n",
    "    # individual scalars is far slower and heavier at tens of millions of values.\n",
    "    intra_chunks = []\n",
    "    inter_chunks = []\n",
    "    is_other = np.ones(n_rows, dtype=bool)\n",
    "\n",
    "    for slide_idx in range(num_slides):\n",
    "        start = slide_idx * patches_per_slide\n",
    "        stop = start + patches_per_slide\n",
    "        slide_emb = embeddings[start:stop]\n",
    "\n",
    "        is_other[:] = True\n",
    "        is_other[start:stop] = False\n",
    "\n",
    "        intra_chunks.append(pdist(slide_emb, metric='euclidean'))\n",
    "        # Rows = other patches, columns = this slide's patches; ravel keeps that order.\n",
    "        inter_chunks.append(cdist(embeddings[is_other], slide_emb, metric='euclidean').ravel())\n",
    "\n",
    "    if not intra_chunks:\n",
    "        return np.array([]), np.array([])\n",
    "    return np.concatenate(intra_chunks), np.concatenate(inter_chunks)\n",
    "\n",
    "\n",
    "# Slide specificity: Cliff's delta of intra-slide vs inter-slide distances,\n",
    "# computed per batch and summarised per model as mean / min / max.\n",
    "results = []\n",
    "for model in tqdm(models, desc=\"Processing models\"):\n",
    "    delta_vals = []\n",
    "\n",
    "    for batch_idx in range(n_batches):\n",
    "        embeddings_list = []\n",
    "        for cancer_type in cancer_types:\n",
    "            batch_filename = f\"embeddings_{cancer_type}{norm_tag}-{model}-batch{batch_idx}.npy\"\n",
    "            batch_path = os.path.join(batched_embeddings_path, batch_filename)\n",
    "            emb = np.load(batch_path)\n",
    "            embeddings_list.append(emb)\n",
    "\n",
    "        all_embeddings = np.vstack(embeddings_list)\n",
    "        intra, inter = calculate_distances_per_slide(all_embeddings, num_slides=num_slides_per_batch * len(cancer_types), patches_per_slide=num_patches_per_batch)\n",
    "\n",
    "        delta_vals.append(efficient_cliffs_delta(intra, inter))\n",
    "\n",
    "    delta_vals = np.array(delta_vals)\n",
    "\n",
    "    results.append({\n",
    "        \"model\": model,\n",
    "        \"cliffs_delta_mean\": delta_vals.mean(),\n",
    "        \"cliffs_delta_min\": delta_vals.min(),\n",
    "        \"cliffs_delta_max\": delta_vals.max()\n",
    "    })\n",
    "\n",
    "results_df = pd.DataFrame(results)\n",
    "results_df = results_df.sort_values(by=\"cliffs_delta_mean\", ascending=False).reset_index(drop=True)\n",
    "results_df.to_csv(plot_path + f\"slide_specificity{norm_tag}.csv\", index=False)\n",
    "print(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc60a30f-83f5-43d8-8f88-12831f6ee8d1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Disease Specificity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f22f1de-75c9-4df9-8964-41876863a80c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disease specificity: Cliff's delta of\n",
    "#   intra = distances between patches of the same cancer type but different slides, vs\n",
    "#   inter = distances between patches of different cancer types,\n",
    "# computed per batch and summarised per model as mean / min / max.\n",
    "# cdist is used for the all-pairs distances: subtracting broadcast arrays would\n",
    "# materialise an (N, N, D) intermediate before the norm is taken.\n",
    "\n",
    "# Slide index of every patch row (batch files store patches slide after slide).\n",
    "n_rows_per_type = num_slides_per_batch * num_patches_per_batch\n",
    "slide_ids = np.arange(n_rows_per_type) // num_patches_per_batch\n",
    "# True where the row patch belongs to an earlier slide than the column patch:\n",
    "# selects each cross-slide patch pair once and drops same-slide pairs.\n",
    "cross_slide = slide_ids[:, None] < slide_ids[None, :]\n",
    "\n",
    "results = []\n",
    "for model in tqdm(models, desc=\"Processing models\"):\n",
    "    delta_vals = []\n",
    "\n",
    "    for batch_idx in range(n_batches):\n",
    "        disease_embeddings = {}\n",
    "        for cancer_type in cancer_types:\n",
    "            batch_filename = f\"embeddings_{cancer_type}{norm_tag}-{model}-batch{batch_idx}.npy\"\n",
    "            batch_path = os.path.join(batched_embeddings_path, batch_filename)\n",
    "            emb = np.load(batch_path)\n",
    "            disease_embeddings[cancer_type] = emb\n",
    "\n",
    "        intra_chunks = []\n",
    "        for cancer_type in cancer_types:\n",
    "            emb = disease_embeddings[cancer_type][:n_rows_per_type]\n",
    "            intra_chunks.append(cdist(emb, emb, metric='euclidean')[cross_slide])\n",
    "\n",
    "        inter_chunks = []\n",
    "        for i in range(len(cancer_types)-1):\n",
    "            for j in range(i + 1, len(cancer_types)):\n",
    "                emb1 = disease_embeddings[cancer_types[i]]\n",
    "                emb2 = disease_embeddings[cancer_types[j]]\n",
    "                inter_chunks.append(cdist(emb1, emb2, metric='euclidean').ravel())\n",
    "\n",
    "        intra = np.concatenate(intra_chunks)\n",
    "        inter = np.concatenate(inter_chunks)\n",
    "\n",
    "        delta_vals.append(efficient_cliffs_delta(intra, inter))\n",
    "\n",
    "    delta_vals = np.array(delta_vals)\n",
    "\n",
    "    results.append({\n",
    "        \"model\": model,\n",
    "        \"cliffs_delta_mean\": delta_vals.mean(),\n",
    "        \"cliffs_delta_min\": delta_vals.min(),\n",
    "        \"cliffs_delta_max\": delta_vals.max()\n",
    "    })\n",
    "\n",
    "results_df = pd.DataFrame(results)\n",
    "results_df = results_df.sort_values(by=\"cliffs_delta_mean\", ascending=False).reset_index(drop=True)\n",
    "results_df.to_csv(plot_path + f\"disease_specificity_exclude_same_slide{norm_tag}.csv\", index=False)\n",
    "print(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42663596-1e5c-455d-b5be-37302c6c142d",
   "metadata": {},
   "source": [
    "# Spectral Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3ea52d-e1bc-4c83-9d90-b01dd54d9c26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the full (un-batched) embeddings of every model, stacked over cancer types,\n",
    "# and centre each feature column at zero.\n",
    "embeddings = {}\n",
    "for model in models:\n",
    "    embeddings_list = []\n",
    "    print(f\"\\nProcessing {model}...\")\n",
    "    for cancer_type in cancer_types:\n",
    "        file_path = os.path.join(orig_embeddings_path, f\"embeddings_{cancer_type}{norm_tag}-{model}.npy\")\n",
    "        emb = np.load(file_path)\n",
    "        embeddings_list.append(emb)\n",
    "    embeddings[model] = np.concatenate(embeddings_list, axis=0)\n",
    "    embeddings[model] -= embeddings[model].mean(axis=0)\n",
    "\n",
    "# Singular-value spectrum of each model, normalised to sum to 1.\n",
    "# Only the singular values are used, so U and Vt are not computed.\n",
    "spectra = {}\n",
    "for model in models:\n",
    "    print(model)\n",
    "    S = np.linalg.svd(embeddings[model], compute_uv=False)\n",
    "    normalized_spectrum = S / S.sum()\n",
    "    spectra[model] = normalized_spectrum\n",
    "\n",
    "for model in models:\n",
    "    print(model, embeddings[model].shape)\n",
    "\n",
    "vl_models = ['keep', 'conch', 'plip']  # drawn with dashed lines below\n",
    "p_test = [0.25] + list(range(1, 101))  # percentages of the feature dimension\n",
    "s_sums = np.zeros((len(models), len(p_test)))\n",
    "\n",
    "for i, m in enumerate(models):\n",
    "    n_feat = embeddings[m].shape[-1]\n",
    "    for j, p in enumerate(p_test):\n",
    "        cut_off = int(np.round(n_feat * p/100))\n",
    "        # NOTE(review): the slice keeps cut_off + 1 values, i.e. one more than p% of the\n",
    "        # features - confirm whether the +1 is intended.\n",
    "        s_sums[i, j] = spectra[m][:(cut_off+1)].sum()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for i, m in enumerate(models):\n",
    "    line_style = '--' if m in vl_models else '-'\n",
    "    plt.plot(p_test, s_sums[i], label=model_tags[i], linewidth=2, linestyle=line_style)\n",
    "\n",
    "plt.legend(prop={'weight': 'bold', 'size': 11})\n",
    "ax.set_ylim([0, 1.02])\n",
    "ax.set_aspect(100, adjustable='box')\n",
    "plt.xlabel('Percentage of Features', fontsize=12, fontweight='bold')\n",
    "plt.ylabel('Singular Value Cumulative Sum', fontsize=12, fontweight='bold')\n",
    "\n",
    "# Remove top and right spines\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "plt.savefig(plot_path + f'spectral_analysis{norm_tag}.pdf', dpi=300, bbox_inches='tight', pad_inches=0.05)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
