{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b70c22f6-a938-49d8-aa5d-a83e4ec9c81c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import json\n",
    "import re\n",
    "from collections import defaultdict\n",
    "from constants import BASE_PATH_RESULTS, ds_info_file\n",
    "from helper import load_similarity_matrices, load_sim_matrix, load_ds_info\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import silhouette_score, silhouette_samples\n",
    "from scipy.cluster.hierarchy import dendrogram, linkage\n",
    "from scipy.spatial.distance import pdist, squareform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ef6b605f-277a-4470-b6f4-13f688c3fe4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({\n",
    "    \"font.size\": 8,          # base font size\n",
    "    \"axes.titlesize\": 10,     # subplot titles\n",
    "    \"axes.labelsize\": 8,     # x/y labels\n",
    "    \"xtick.labelsize\": 8,     # x tick labels\n",
    "    \"ytick.labelsize\": 8,     # y tick labels\n",
    "    \"legend.fontsize\": 8,     # legend\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b9f4b1ac-1484-4b05-be40-8791b2143df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "storing_path  = BASE_PATH_RESULTS / \"plots/similarity_matrices\"\n",
    "\n",
    "SAVE = True\n",
    "\n",
    "if SAVE:\n",
    "    storing_path.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "52c9b8bc-a06a-464a-895a-dffb5910d941",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = BASE_PATH_RESULTS.parent / \"model_similarities_iclr_exp\"\n",
    "ds_list = [\n",
    "    'imagenet-subset-50k',\n",
    "    'wds_country211',\n",
    "    'wds_gtsrb',\n",
    "    'wds_vtab_caltech101',\n",
    "    'wds_vtab_diabetic_retinopathy',\n",
    "    'wds_vtab_eurosat',\n",
    "    'wds_vtab_pets',\n",
    "    'wds_fer2013',\n",
    "    'wds_stl10',\n",
    "    'wds_vtab_cifar10',\n",
    "    'wds_vtab_dmlab',\n",
    "    'wds_vtab_flowers',\n",
    "    'wds_vtab_resisc45',\n",
    "    'wds_cars',\n",
    "    'wds_fgvc_aircraft',\n",
    "    'wds_voc2007',\n",
    "    'wds_vtab_cifar100',\n",
    "    'wds_vtab_dtd',\n",
    "    'wds_vtab_pcam',\n",
    "    'wds_vtab_svhn',\n",
    "]\n",
    "sim_metrics = [\n",
    "    'cka_kernel_linear_unbiased',\n",
    "    'cka_kernel_rbf_unbiased_sigma_0.2'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ca99b3ea-fc23-4c53-8898-17fea85f669b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>domain</th>\n",
       "      <th>name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>wds_vtab_cifar10</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>CIFAR-10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_cifar100</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>CIFAR-100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_caltech101</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>Caltech-101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_country211</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>Country-211</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>imagenet-subset-50k</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>ImageNet-1k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_voc2007</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>PASCAL VOC 2007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_stl10</th>\n",
       "      <td>Natural (multi-domain)</td>\n",
       "      <td>STL-10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_fgvc_aircraft</th>\n",
       "      <td>Natural (single-domain)</td>\n",
       "      <td>FGVC Aircraft</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_flowers</th>\n",
       "      <td>Natural (single-domain)</td>\n",
       "      <td>Flowers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_gtsrb</th>\n",
       "      <td>Natural (single-domain)</td>\n",
       "      <td>GTSRB</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_pets</th>\n",
       "      <td>Natural (single-domain)</td>\n",
       "      <td>Pets</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_svhn</th>\n",
       "      <td>Natural (single-domain)</td>\n",
       "      <td>SVHN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_cars</th>\n",
       "      <td>Natural (single-domain)</td>\n",
       "      <td>Stanford Cars</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_diabetic_retinopathy</th>\n",
       "      <td>Specialized</td>\n",
       "      <td>Diabetic Retinopathy</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_eurosat</th>\n",
       "      <td>Specialized</td>\n",
       "      <td>EuroSAT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_pcam</th>\n",
       "      <td>Specialized</td>\n",
       "      <td>PCAM</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_resisc45</th>\n",
       "      <td>Specialized</td>\n",
       "      <td>RESISC45</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_dtd</th>\n",
       "      <td>Structured</td>\n",
       "      <td>DTD</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_vtab_dmlab</th>\n",
       "      <td>Structured</td>\n",
       "      <td>Dmlab</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>wds_fer2013</th>\n",
       "      <td>Structured</td>\n",
       "      <td>FER2013</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                domain                  name\n",
       "wds_vtab_cifar10                Natural (multi-domain)              CIFAR-10\n",
       "wds_vtab_cifar100               Natural (multi-domain)             CIFAR-100\n",
       "wds_vtab_caltech101             Natural (multi-domain)           Caltech-101\n",
       "wds_country211                  Natural (multi-domain)           Country-211\n",
       "imagenet-subset-50k             Natural (multi-domain)           ImageNet-1k\n",
       "wds_voc2007                     Natural (multi-domain)       PASCAL VOC 2007\n",
       "wds_stl10                       Natural (multi-domain)                STL-10\n",
       "wds_fgvc_aircraft              Natural (single-domain)         FGVC Aircraft\n",
       "wds_vtab_flowers               Natural (single-domain)               Flowers\n",
       "wds_gtsrb                      Natural (single-domain)                 GTSRB\n",
       "wds_vtab_pets                  Natural (single-domain)                  Pets\n",
       "wds_vtab_svhn                  Natural (single-domain)                  SVHN\n",
       "wds_cars                       Natural (single-domain)         Stanford Cars\n",
       "wds_vtab_diabetic_retinopathy              Specialized  Diabetic Retinopathy\n",
       "wds_vtab_eurosat                           Specialized               EuroSAT\n",
       "wds_vtab_pcam                              Specialized                  PCAM\n",
       "wds_vtab_resisc45                          Specialized              RESISC45\n",
       "wds_vtab_dtd                                Structured                   DTD\n",
       "wds_vtab_dmlab                              Structured                 Dmlab\n",
       "wds_fer2013                                 Structured               FER2013"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds_info = load_ds_info(ds_info_file)\n",
    "ds_info.loc['imagenet-subset-50k'] = ds_info.loc['imagenet-subset-10k'].copy()\n",
    "ds_info = ds_info.loc[ds_list]\n",
    "ds_info = ds_info.sort_values(['domain', 'name'])\n",
    "ds_list = ds_info.index.tolist()\n",
    "ds_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0237845b-7959-47b9-94bb-fb7471a66a90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 210 ms, sys: 53.8 ms, total: 263 ms\n",
      "Wall time: 837 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "def sort_key(item):\n",
    "\n",
    "    if '_cls@' in item:\n",
    "        token_type = 0  \n",
    "    elif '_ap@' in item:\n",
    "        token_type = 1 \n",
    "    else:\n",
    "        token_type = 2\n",
    "    \n",
    "    match = re.search(r'\\.(\\d+)\\.', item)\n",
    "    if match:\n",
    "        block_num = int(match.group(1))\n",
    "    else:\n",
    "        block_num = np.inf \n",
    "    \n",
    "    return (token_type, block_num)\n",
    "\n",
    "res = dict()\n",
    "for path in base_path.rglob(\"similarity_matrix.pt\"):\n",
    "    ds = path.parts[5]\n",
    "    sim_metric = path.parts[6]\n",
    "    model = path.parts[7]\n",
    "    curr_sim_mat = load_sim_matrix(path.parent, allowed_models=None)\n",
    "    # sort rows and cols\n",
    "    sorted_list = sorted(curr_sim_mat.index.to_list(), key=sort_key)\n",
    "    curr_sim_mat = curr_sim_mat.loc[sorted_list, sorted_list]\n",
    "    # rename \n",
    "    np.fill_diagonal(curr_sim_mat.values, 1)\n",
    "    if sim_metric not in res.keys():\n",
    "        res[sim_metric] = {}\n",
    "    if ds not in res[sim_metric].keys(): \n",
    "        res[sim_metric][ds] = {}\n",
    "    res[sim_metric][ds][model] = curr_sim_mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e4ca5179-313d-422d-94f7-b2d55f19f134",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cka_kernel_rbf_unbiased_sigma_0.2\n",
      "19 ['wds_cars', 'wds_country211', 'wds_fer2013', 'wds_fgvc_aircraft', 'wds_gtsrb', 'wds_stl10', 'wds_voc2007', 'wds_vtab_caltech101', 'wds_vtab_cifar10', 'wds_vtab_cifar100', 'wds_vtab_diabetic_retinopathy', 'wds_vtab_dmlab', 'wds_vtab_dtd', 'wds_vtab_eurosat', 'wds_vtab_flowers', 'wds_vtab_pcam', 'wds_vtab_pets', 'wds_vtab_resisc45', 'wds_vtab_svhn']\n",
      "cka_kernel_linear_unbiased\n",
      "20 ['imagenet-subset-50k', 'wds_cars', 'wds_country211', 'wds_fer2013', 'wds_fgvc_aircraft', 'wds_gtsrb', 'wds_stl10', 'wds_voc2007', 'wds_vtab_caltech101', 'wds_vtab_cifar10', 'wds_vtab_cifar100', 'wds_vtab_diabetic_retinopathy', 'wds_vtab_dmlab', 'wds_vtab_dtd', 'wds_vtab_eurosat', 'wds_vtab_flowers', 'wds_vtab_pcam', 'wds_vtab_pets', 'wds_vtab_resisc45', 'wds_vtab_svhn']\n"
     ]
    }
   ],
   "source": [
    "for res_k in res.keys():\n",
    "    print(res_k)\n",
    "    all_ds = sorted(res[res_k].keys())\n",
    "    print(len(all_ds), all_ds)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1cf4c039-302a-4581-9cbe-81ca3910635b",
   "metadata": {},
   "outputs": [],
   "source": [
    "name_mapping = {\n",
    "    \"OpenCLIP_ViT-B-16_openai_both\": \"CLIP-ViT-B\",\n",
    "    \"dinov2-vit-base-p14_both\": \"DINOV2-ViT-B\",\n",
    "    \"vit_base_patch16_224_both\": \"ViT-B\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b19072b8-bd5a-41b9-bd6d-cf4d92b065a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cluster_matrix_rows(matrix, method='kmeans', n_clusters=3, **kwargs):\n",
    "    if method == 'kmeans':\n",
    "        clusterer = KMeans(n_clusters=n_clusters, random_state=42, **kwargs)\n",
    "        cluster_labels = clusterer.fit_predict(matrix)\n",
    "        \n",
    "    else:\n",
    "        raise ValueError(\"No other method available at the moment\")\n",
    "        \n",
    "    return cluster_labels, clusterer\n",
    "\n",
    "\n",
    "def find_optimal_clusters(matrix, max_clusters=10, method='kmeans', save_path=None):\n",
    "    \"\"\"\n",
    "    Find optimal number of clusters using elbow method and silhouette score.\n",
    "    \"\"\"\n",
    "    n_clusters_range = range(2, min(max_clusters + 1, len(matrix)))\n",
    "    \n",
    "    if method == 'kmeans':\n",
    "        inertias = []\n",
    "        silhouette_scores = []\n",
    "        \n",
    "        for n_clusters in n_clusters_range:\n",
    "            kmeans = KMeans(n_clusters=n_clusters, random_state=42)\n",
    "            cluster_labels = kmeans.fit_predict(matrix)\n",
    "            \n",
    "            inertias.append(kmeans.inertia_)\n",
    "            silhouette_scores.append(silhouette_score(matrix, cluster_labels))\n",
    "        \n",
    "        fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "        \n",
    "        # Elbow plot\n",
    "        axes[0].plot(n_clusters_range, inertias, 'bo-')\n",
    "        axes[0].set_xlabel('Number of Clusters')\n",
    "        axes[0].set_ylabel('Inertia')\n",
    "        axes[0].set_title('Elbow Method')\n",
    "        axes[0].grid(True)\n",
    "        \n",
    "        # Silhouette score plot\n",
    "        axes[1].plot(n_clusters_range, silhouette_scores, 'ro-')\n",
    "        axes[1].set_xlabel('Number of Clusters')\n",
    "        axes[1].set_ylabel('Silhouette Score')\n",
    "        axes[1].set_title('Silhouette Score')\n",
    "        axes[1].grid(True)\n",
    "        \n",
    "        top_two_cluster = np.argsort(silhouette_scores)[-2:]\n",
    "        optimal_n = np.array(n_clusters_range)[top_two_cluster]\n",
    "        for optim in optimal_n:\n",
    "            axes[1].axvline(optim, c='black', ls=\":\",alpha=0.75)\n",
    "\n",
    "        plt.tight_layout()\n",
    "        if save_path:\n",
    "            plt.savefig(save_path,\n",
    "                        bbox_inches='tight',\n",
    "                        edgecolor='none'\n",
    "                       )\n",
    "            print(f\"Elbo/ Silhouette visualization saved to: {save_path}. Closing figure ...\")\n",
    "            plt.close(fig)\n",
    "        else:\n",
    "            plt.show(fig)\n",
    "    \n",
    "        return optimal_n\n",
    "    \n",
    "    return None\n",
    "\n",
    "def pca_visualization(matrix, cluster_labels, row_names=None, save_path=None):\n",
    "    \"\"\"\n",
    "    Visualize clusters in PCA space showing first 4 components.\n",
    "    \"\"\"\n",
    "    # Apply PCA to get 4 components\n",
    "    pca = PCA(n_components=4)\n",
    "    matrix_pca = pca.fit_transform(matrix)\n",
    "    \n",
    "    # Create subplots for different PC combinations\n",
    "    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
    "    axes = axes.flatten()\n",
    "    \n",
    "    # Define PC combinations to plot\n",
    "    pc_combinations = [\n",
    "        (0, 1, 'PC1', 'PC2'),\n",
    "        (0, 2, 'PC1', 'PC3'), \n",
    "        (0, 3, 'PC1', 'PC4'),\n",
    "        (1, 2, 'PC2', 'PC3'),\n",
    "        (1, 3, 'PC2', 'PC4'),\n",
    "        (2, 3, 'PC3', 'PC4')\n",
    "    ]\n",
    "    \n",
    "    # Plot each cluster with different colors\n",
    "    unique_clusters = np.unique(cluster_labels)\n",
    "    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_clusters)))\n",
    "    \n",
    "    for plot_idx, (pc_x, pc_y, pc_x_name, pc_y_name) in enumerate(pc_combinations):\n",
    "        ax = axes[plot_idx]\n",
    "        \n",
    "        for i, cluster in enumerate(unique_clusters):\n",
    "            mask = cluster_labels == cluster\n",
    "            ax.scatter(matrix_pca[mask, pc_x], matrix_pca[mask, pc_y], \n",
    "                      c=[colors[i]], label=f'Cluster {cluster}', \n",
    "                      alpha=0.7, s=100)\n",
    "            \n",
    "            if row_names is not None:\n",
    "                for j, (x, y) in enumerate(zip(matrix_pca[mask, pc_x], matrix_pca[mask, pc_y])):\n",
    "                    row_idx = np.where(mask)[0][j]\n",
    "                    ds_name = row_names[row_idx]\n",
    "                    ax.annotate(ds_info.loc[ds_name, 'name'], (x, y), \n",
    "                               xytext=(5, 5), textcoords='offset points',\n",
    "                               fontsize=8, alpha=0.8)\n",
    "        \n",
    "        ax.set_xlabel(f'{pc_x_name} ({pca.explained_variance_ratio_[pc_x]:.2%} variance)')\n",
    "        ax.set_ylabel(f'{pc_y_name} ({pca.explained_variance_ratio_[pc_y]:.2%} variance)')\n",
    "        ax.set_title(f'Clusters: {pc_x_name} vs {pc_y_name}')\n",
    "        ax.grid(True, alpha=0.3)\n",
    "    \n",
    "        if plot_idx == 0:\n",
    "            ax.legend()\n",
    "    \n",
    "    plt.suptitle('Clusters Visualized in PCA Space (Multiple PC Combinations)', y=1.0, fontsize=10)\n",
    "    plt.tight_layout()\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, dpi=300, bbox_inches='tight', \n",
    "                   facecolor='white', edgecolor='none')\n",
    "        print(f\"PCA visualization saved to: {save_path}. Closing figure ...\")\n",
    "        plt.close(fig)\n",
    "    else:\n",
    "        plt.show(fig)\n",
    "    return matrix_pca, pca"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a0b9375f-2539-434b-ad20-290c538eda17",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OpenCLIP_ViT-B-16_openai_both\n",
      "19 (19, 325)\n",
      "Ensured curr_dir=PosixPath('/home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_OpenCLIP_ViT-B-16_openai_both') exists\n",
      "Elbo/ Silhouette visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_OpenCLIP_ViT-B-16_openai_both/find_cluster.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_OpenCLIP_ViT-B-16_openai_both/pca_clustering.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_OpenCLIP_ViT-B-16_openai_both/pca_clustering.pdf. Closing figure ...\n",
      "\n",
      "\n",
      "dinov2-vit-base-p14_both\n",
      "19 (19, 325)\n",
      "Ensured curr_dir=PosixPath('/home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_dinov2-vit-base-p14_both') exists\n",
      "Elbo/ Silhouette visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_dinov2-vit-base-p14_both/find_cluster.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_dinov2-vit-base-p14_both/pca_clustering.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_dinov2-vit-base-p14_both/pca_clustering.pdf. Closing figure ...\n",
      "\n",
      "\n",
      "vit_base_patch16_224_both\n",
      "19 (19, 325)\n",
      "Ensured curr_dir=PosixPath('/home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_vit_base_patch16_224_both') exists\n",
      "Elbo/ Silhouette visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_vit_base_patch16_224_both/find_cluster.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_vit_base_patch16_224_both/pca_clustering.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_rbf_unbiased_sigma_0.2_vit_base_patch16_224_both/pca_clustering.pdf. Closing figure ...\n",
      "\n",
      "\n",
      "OpenCLIP_ViT-B-16_openai_both\n",
      "20 (20, 325)\n",
      "Ensured curr_dir=PosixPath('/home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_OpenCLIP_ViT-B-16_openai_both') exists\n",
      "Elbo/ Silhouette visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_OpenCLIP_ViT-B-16_openai_both/find_cluster.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_OpenCLIP_ViT-B-16_openai_both/pca_clustering.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_OpenCLIP_ViT-B-16_openai_both/pca_clustering.pdf. Closing figure ...\n",
      "\n",
      "\n",
      "dinov2-vit-base-p14_both\n",
      "19 (19, 325)\n",
      "Ensured curr_dir=PosixPath('/home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_dinov2-vit-base-p14_both') exists\n",
      "Elbo/ Silhouette visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_dinov2-vit-base-p14_both/find_cluster.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_dinov2-vit-base-p14_both/pca_clustering.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_dinov2-vit-base-p14_both/pca_clustering.pdf. Closing figure ...\n",
      "\n",
      "\n",
      "vit_base_patch16_224_both\n",
      "20 (20, 325)\n",
      "Ensured curr_dir=PosixPath('/home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_vit_base_patch16_224_both') exists\n",
      "Elbo/ Silhouette visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_vit_base_patch16_224_both/find_cluster.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_vit_base_patch16_224_both/pca_clustering.pdf. Closing figure ...\n",
      "PCA visualization saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/clustering/cka_kernel_linear_unbiased_vit_base_patch16_224_both/pca_clustering.pdf. Closing figure ...\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def flatten_sim(curr_sim_mat):\n",
    "    idx1, idx2 = np.triu_indices_from(curr_sim_mat, k=1)\n",
    "    return curr_sim_mat.to_numpy()[idx1, idx2]\n",
    "\n",
    "clustering_res = dict()\n",
    "for metric, metric_data in res.items():\n",
    "    clustering_res[metric] = {}\n",
    "    for mid in name_mapping.keys():\n",
    "        clustering_res[metric][mid] = {}\n",
    "        print(mid)\n",
    "        mid_data = {ds : ds_sim_mats[mid] for ds, ds_sim_mats in metric_data.items() if mid in ds_sim_mats}\n",
    "        mid_data = {ds : flatten_sim(curr_mat) for ds, curr_mat in mid_data.items()}\n",
    "\n",
    "        avail_ds = list(mid_data.keys())\n",
    "        mid_all_sims = np.stack(list(mid_data.values()))\n",
    "        print(len(avail_ds), mid_all_sims.shape)\n",
    "        \n",
    "        curr_dir = storing_path / \"clustering\" / f\"{metric}_{mid}\"\n",
    "        if SAVE:\n",
    "            curr_dir.mkdir(parents=True, exist_ok=True)\n",
    "            print(f\"Ensured {curr_dir=} exists\")\n",
    "        else:\n",
    "            curr_dir = None\n",
    "        \n",
    "        top_ks = find_optimal_clusters(\n",
    "            mid_all_sims, \n",
    "            max_clusters=len(avail_ds),\n",
    "            save_path=(curr_dir / \"find_cluster.pdf\") if curr_dir else None\n",
    "        )\n",
    "        cluster_res = {}\n",
    "        for k, optimal_k in enumerate(top_ks[::-1]):\n",
    "            cluster_labels, clusterer = cluster_matrix_rows(mid_all_sims, method='kmeans', n_clusters=optimal_k)\n",
    "            matrix_2d, pca = pca_visualization(\n",
    "                mid_all_sims, \n",
    "                cluster_labels, \n",
    "                avail_ds,\n",
    "                save_path=(curr_dir / \"pca_clustering.pdf\") if curr_dir else None\n",
    "            )\n",
    "            cluster_res[f\"top_{k+1}_cluster_lbl\"] = cluster_labels\n",
    "\n",
    "        cluster_res['ds'] = avail_ds\n",
    "        clustering_res[metric][mid] = pd.DataFrame(cluster_res).set_index(\"ds\")\n",
    "        print()\n",
    "        print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a7f5eb99-d430-4233-9fe7-e92e0efac1e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "## merge everything in ds_info\n",
    "for metric, metric_data in clustering_res.items():\n",
    "    for mid, mid_data in metric_data.items():\n",
    "        cols = mid_data.columns.tolist()\n",
    "        new_cols = [f\"{metric}_{mid}_{col}\" for col in cols]\n",
    "        for old_col, new_col in zip(cols, new_cols):\n",
    "            ds_info.loc[mid_data.index, new_col] = mid_data.loc[mid_data.index, old_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "21f0a822-ee6f-4c47-a03c-2e5b4f2bce22",
   "metadata": {},
   "outputs": [],
   "source": [
    "if SAVE:\n",
    "    ds_info.to_csv(storing_path / \"clustering\" / \"ds_info_with_cluster_assignments.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c7ba5e22-410e-4e6a-b886-7bf2818bfed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import textwrap\n",
    "\n",
    "# def plot_mean_std_heatmap(mean_mat, std_mat):\n",
    "#     fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
    "    \n",
    "#     # Global min/max for consistent color scale\n",
    "#     vmin_mean, vmax_mean = mean_mat.min(), mean_mat.max()\n",
    "#     vmin_std, vmax_std = std_mat.min(), std_mat.max()\n",
    "    \n",
    "#     # Plot mean heatmap\n",
    "#     sns.heatmap(mean_mat, \n",
    "#                ax=axes[0],\n",
    "#                cmap='viridis',\n",
    "#                square=True,\n",
    "#                cbar=True,\n",
    "#                vmin=vmin_mean,\n",
    "#                vmax=vmax_mean,\n",
    "#                xticklabels=False,\n",
    "#                yticklabels=False)\n",
    "#     axes[0].set_title(f'Mean\\n{metric} | {mid} | {col}={group}', \n",
    "#                      fontsize=12, pad=10)\n",
    "    \n",
    "#     # Plot std heatmap\n",
    "#     sns.heatmap(std_mat, \n",
    "#                ax=axes[1],\n",
    "#                cmap='plasma',  # Different colormap for std\n",
    "#                square=True,\n",
    "#                cbar=True,\n",
    "#                vmin=vmin_std,\n",
    "#                vmax=vmax_std,\n",
    "#                xticklabels=False,\n",
    "#                yticklabels=False)\n",
    "#     axes[1].set_title(f'Std Dev\\n{metric} | {mid} | {col}={group}', \n",
    "#                      fontsize=12, pad=10)\n",
    "    \n",
    "#     wrapped_datasets = textwrap.fill(', '.join(grouped_ds), width=120)\n",
    "    \n",
    "#     plt.suptitle(f'Datasets (n={len(grouped_ds)}): {wrapped_datasets}', \n",
    "#                fontsize=10, y=0.98)\n",
    "    \n",
    "#     # Adjust layout\n",
    "#     plt.tight_layout()\n",
    "#     plt.subplots_adjust(top=0.85)  # Make room for suptitle\n",
    "#     plt.show(fig)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f4109b7c-4e91-47ce-8fde-b9ce78ebe302",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cnt = 0 \n",
    "# for metric, metric_data in clustering_res.items():\n",
    "#     for mid, mid_data in metric_data.items():\n",
    "#         for col in mid_data.columns:\n",
    "#             for group, group_data in mid_data.groupby(col):\n",
    "#                 print(metric, mid, col, group) \n",
    "#                 grouped_ds = group_data.index.tolist()\n",
    "#                 mat = [res[metric][ds][mid] for ds in grouped_ds]\n",
    "#                 mat = np.stack(mat)\n",
    "#                 print(mat.shape)\n",
    "#                 mean_mat = np.mean(mat, axis=0)\n",
    "#                 std_mat = np.std(mat, axis=0)\n",
    "#                 print(mean_mat.shape, std_mat.shape)\n",
    "#                 print()\n",
    "#                 plot_mean_std_heatmap(mean_mat, std_mat)\n",
    "#                 cnt +=1\n",
    "#     break\n",
    "# print(cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a81163d8-45f2-42d4-8563-e816bde9f56c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fmt_lbl(curr_lbl):\n",
    "    match = re.search(r'\\.(\\d+)\\.', curr_lbl)\n",
    "    if match:\n",
    "        suffix = f\"B_{int(match.group(1))}\"\n",
    "    else:\n",
    "        suffix = \"last\"\n",
    "    if \"_cls@\" in curr_lbl:\n",
    "        return f\"cls@{suffix}\"\n",
    "    elif \"_ap@\" in curr_lbl:\n",
    "        return f\"avg.pool@{suffix}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ed03de3e-ee41-40ba-9aea-2a0a6e95ab48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# base_size = 2\n",
    "# vmin, vmax = 0, 1\n",
    "\n",
    "# ds_subset = ds_info.groupby('domain').sample(3, random_state=42)\n",
    "# ds_subset_list = ds_subset.index.tolist()\n",
    "\n",
    "# for metric_name, sim_mats in res.items():\n",
    "\n",
    "#     curr_ds_list = [ds for ds in ds_subset_list if ds in sim_mats.keys()]\n",
    "    \n",
    "#     n_ds = len(curr_ds_list)\n",
    "#     n_models = len(sim_mats[next(iter(sim_mats))])\n",
    "\n",
    "#     fig, all_axes = plt.subplots(n_ds, n_models, figsize=(base_size * n_models, base_size * n_ds), sharex=True, sharey=True)\n",
    "    \n",
    "#     for i, ds in enumerate(curr_ds_list):\n",
    "#         ds_sim_mats = sim_mats[ds]\n",
    "#         axes = all_axes[i,:]\n",
    "#         for idx, model_name in enumerate(sorted(ds_sim_mats.keys())):\n",
    "#             sim_matrix = ds_sim_mats[model_name]\n",
    "#             sns.heatmap(\n",
    "#                 sim_matrix,\n",
    "#                 ax=axes[idx],\n",
    "#                 cmap=\"viridis\",\n",
    "#                 square=True,\n",
    "#                 cbar=False,\n",
    "#                 vmin=vmin,\n",
    "#                 vmax=vmax,\n",
    "#                 xticklabels=False,\n",
    "#                 yticklabels=False,\n",
    "#             )\n",
    "#             axes[idx].axvline(len(sim_matrix) // 2, c=\"white\", alpha=0.5, ls=\"--\")\n",
    "#             axes[idx].axhline(len(sim_matrix) // 2, c=\"white\", alpha=0.5, ls=\"--\")\n",
    "#             if idx == 0:\n",
    "#                 axes[idx].set_ylabel(ds_info.loc[ds, 'name'] + \"\\n\" + ds_info.loc[ds, 'domain'], fontsize=10)\n",
    "#             if i == 0:\n",
    "#                 axes[idx].set_title(name_mapping[model_name], pad=10)\n",
    "#     fig.tight_layout()\n",
    "#     plt.show(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a43c08c3-9b1e-445c-bae3-cf6d177f7091",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19 10 1 9\n",
      "All heatmaps saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/all_heatmaps_cka_kernel_rbf_unbiased_sigma_0.2.pdf. Closing figure ...\n",
      "All heatmaps saved to: /home/space/rep2rep/results_iclr_exp/plots/similarity_matrices/all_heatmaps_cka_kernel_linear_unbiased.pdf. Closing figure ...\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import textwrap\n",
    "\n",
    "base_size = 2\n",
    "vmin, vmax = 0, 1\n",
    "max_cols = 10\n",
    "\n",
    "# ds_subset = ds_info.groupby('domain').sample(3, random_state=42)\n",
    "\n",
    "ds_subset_list = ds_info.index.tolist()\n",
    "\n",
    "for metric_name, sim_mats in res.items():\n",
    "    curr_ds_list = [ds for ds in ds_subset_list if ds in sim_mats.keys()]\n",
    "    \n",
    "    n_ds = len(curr_ds_list)\n",
    "    n_models = len(sim_mats[next(iter(sim_mats))])\n",
    "\n",
    "    ncols = min(max_cols, n_ds)\n",
    "    if ncols != n_ds:\n",
    "        rows_of_same_model = int(np.ceil(n_ds / max_cols))\n",
    "        nrow =  rows_of_same_model * n_models\n",
    "    else:\n",
    "        rows_of_same_model = 1\n",
    "        nrow = n_models\n",
    "\n",
    "    fig, all_axes = plt.subplots(nrow, \n",
    "                                 ncols, \n",
    "                                 figsize=(base_size * ncols, base_size * nrow), \n",
    "                                 sharex=True, \n",
    "                                 sharey=True)\n",
    "    \n",
    "    model_names = sorted(sim_mats[next(iter(sim_mats))].keys())\n",
    "\n",
    "    rows_with_title = []\n",
    "    \n",
    "    for i, model_name in enumerate(model_names): \n",
    "        for j, ds in enumerate(curr_ds_list):  # Loop over datasets (columns)\n",
    "            if ncols != n_ds:\n",
    "                row_in_model = j // max_cols\n",
    "                col = j % max_cols\n",
    "                row = i + n_models * row_in_model\n",
    "            else:\n",
    "                row = model_idx\n",
    "                col = ds_idx\n",
    "            \n",
    "            ax = all_axes[row, col]\n",
    "            \n",
    "            ds_sim_mats = sim_mats[ds]\n",
    "            try:\n",
    "                sim_matrix = ds_sim_mats[model_name]\n",
    "            except KeyError:\n",
    "                ax.axis(\"off\")\n",
    "                continue\n",
    "            \n",
    "            sns.heatmap(\n",
    "                sim_matrix,\n",
    "                ax=ax,\n",
    "                cmap=\"viridis\",\n",
    "                square=True,\n",
    "                cbar=False,\n",
    "                vmin=vmin,\n",
    "                vmax=vmax,\n",
    "                xticklabels=False,\n",
    "                yticklabels=False,\n",
    "            )\n",
    "            \n",
    "            ax.axvline(len(sim_matrix) // 2, c=\"white\", alpha=0.5, ls=\"--\")\n",
    "            ax.axhline(len(sim_matrix) // 2, c=\"white\", alpha=0.5, ls=\"--\")\n",
    "            \n",
    "            if i == 0:\n",
    "                ds_domain = \"\\n\".join(ds_info.loc[ds, 'domain'].split(\" \"))\n",
    "                ax.set_title(ds_info.loc[ds, 'name'] + \"\\n\" + ds_domain, pad=10)\n",
    "                # ax.set_title(ds_info.loc[ds, 'name'], pad=10)\n",
    "                if row>0:\n",
    "                    rows_with_title.append(row)\n",
    "            \n",
    "            if col == 0:\n",
    "                ax.set_ylabel(name_mapping[model_name], fontsize=10)\n",
    "    \n",
    "    if n_ds % ncols !=0:\n",
    "        not_filled = ncols - (n_ds % ncols )\n",
    "        print(n_ds, ncols, not_filled, ncols - not_filled)\n",
    "        \n",
    "        start_row =  (rows_of_same_model - 1) * n_models\n",
    "        for col in range((ncols - not_filled), ncols):\n",
    "            for k in range(start_row, nrow):\n",
    "                all_axes[k, col].axis(\"off\")\n",
    "    \n",
    "    # fig.tight_layout()\n",
    "    fig.subplots_adjust(\n",
    "        hspace=0.1,  # Height spacing between rows\n",
    "        wspace=0.1,  # Width spacing between columns\n",
    "    )\n",
    "    \n",
    "    pct_to_add = 0\n",
    "    const_add = 0.045\n",
    "    for row in range(min(rows_with_title), nrow):  # rows after 3\n",
    "        if row in rows_with_title:\n",
    "                pct_to_add += const_add\n",
    "        for col in range(ncols):\n",
    "            ax = all_axes[row, col]\n",
    "            pos = ax.get_position()\n",
    "            ax.set_position([pos.x0, pos.y0 - pct_to_add, pos.width, pos.height])\n",
    "            \n",
    "    if SAVE:\n",
    "        save_path = storing_path / f\"all_heatmaps_{metric_name}.pdf\"\n",
    "        plt.savefig(save_path, bbox_inches='tight', \n",
    "                   facecolor='white', edgecolor='none')\n",
    "        print(f\"All heatmaps saved to: {save_path}. Closing figure ...\")\n",
    "        plt.close(fig)\n",
    "    else:\n",
    "        plt.show(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5a9706-3d1a-40ee-a02f-8f224f767738",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
