{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70fe270a-c1d8-4a95-9405-4068e841d52e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import wasserstein_distance\n",
    "from itertools import combinations\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb26155-773a-47c9-9006-ac697192d998",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1824d69d-5e8d-4724-a26b-d97b1ae4cd2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sims import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50d8c02-f7e8-4445-900c-2a4605031c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NSA criterion used by every comparison section below\n",
    "nsa_criterion = NSALoss()\n",
    "\n",
    "# Later cells call `rtd(...)`, but its only definition was commented out, so a fresh\n",
    "# kernel would hit NameError unless the star-import from `sims` supplies that name.\n",
    "# Define it here only when it is missing, so an existing `rtd` is left untouched.\n",
    "# NOTE(review): assumes `RTDLoss` is exported by `sims`, as the previously\n",
    "# commented-out line implied - confirm.\n",
    "if 'rtd' not in globals():\n",
    "    rtd = RTDLoss(dim=1, lp=1.0, engine=\"ripser\", is_sym=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b44c479-f9f2-4868-b522-7ac77d286507",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architectures whose saved feature vectors are compared in this notebook\n",
    "models = ['GCN', 'SAGE', 'GAT', 'CGCN']\n",
    "\n",
    "# Dataset selector; set to 'Flickr' to run on the other dataset\n",
    "dataset_name = 'Amazon'\n",
    "\n",
    "# Root folder holding one sub-directory of .npz files per model\n",
    "path = f'model_data/{dataset_name}/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67995ae2-dd2f-464d-8857-6e43ed7732d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Maps 'MODEL/TASK_RUNID_' to the highest epoch for which a .npz file exists.\n",
    "max_epochs = {}\n",
    "for model_name in models:\n",
    "    for file_name in os.listdir(path + model_name):\n",
    "        # splitext tolerates names with zero or several dots; unpacking a plain\n",
    "        # split on '.' into two names raises ValueError for those entries.\n",
    "        stem, ext = os.path.splitext(file_name)\n",
    "        if ext != '.npz':\n",
    "            continue\n",
    "\n",
    "        # Stems are expected to look like TASK_RUNID_EPOCH; skip anything else\n",
    "        # instead of crashing on an unrelated .npz file.\n",
    "        parts = stem.split('_')\n",
    "        if len(parts) != 3 or not parts[2].isdecimal():\n",
    "            continue\n",
    "        # Local names differ from the notebook-level `task` so this scan does not\n",
    "        # overwrite it as a side effect.\n",
    "        task_name, run_id, epoch = parts[0], parts[1], int(parts[2])\n",
    "        identifier = f\"{model_name}/{task_name}_{run_id}_\"\n",
    "\n",
    "        # Keep only the largest epoch seen for this model/task/run combination\n",
    "        if identifier not in max_epochs or epoch > max_epochs[identifier]:\n",
    "            max_epochs[identifier] = epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f93b766-372d-41ed-9a8b-ac54568760f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the highest saved epoch found for each model/task/run identifier\n",
    "max_epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e029591-42f1-4938-ba5b-3209ea1ffd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selection shared by the test sections below: run ids and task tag as they appear\n",
    "# in the saved file names, plus the architectures to iterate over.\n",
    "run_ids = [str(run_number) for run_number in (1, 2)]\n",
    "task = 'NC'\n",
    "models = 'GCN SAGE GAT CGCN'.split()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0efa8ca-52fc-4ac8-9a2a-6783785d06a1",
   "metadata": {},
   "source": [
    "## Sanity Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7c61ec8-8266-47b5-bbcc-03bd5b42ed12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask PyTorch to release cached, currently unused GPU memory before the cells below run\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fabfae2c-bdce-4266-b6eb-6de908998ed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy inputs: two 3x3 integer matrices that differ only in the last entry (9 vs 10)\n",
    "a = torch.tensor([\n",
    "    [1, 2, 3],\n",
    "    [4, 5, 6],\n",
    "    [7, 8, 9],\n",
    "])\n",
    "b = torch.tensor([\n",
    "    [1, 2, 3],\n",
    "    [4, 5, 6],\n",
    "    [7, 8, 10],\n",
    "])\n",
    "\n",
    "# Cell output is whatever the RTD criterion returns for this pair\n",
    "rtd(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3290822-176f-466b-bf58-affaa97df91c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity test: for each architecture, compare the last saved epoch of the two runs\n",
    "# layer by layer and collect one matrix per metric in heatmap_holder[model][metric].\n",
    "heatmap_holder = {}\n",
    "\n",
    "for model_name in models:\n",
    "    print(\"Currently running for: \",model_name)\n",
    "    heatmap_holder[model_name]={}\n",
    "    identifier1 = f'{model_name}/{task}_{run_ids[0]}_'\n",
    "    identifier2 = f'{model_name}/{task}_{run_ids[1]}_'\n",
    "\n",
    "    fv_path1 = path+identifier1+str(max_epochs[identifier1])\n",
    "    fv_path2 = path+identifier2+str(max_epochs[identifier2])\n",
    "    print(fv_path1)\n",
    "    print(fv_path2)\n",
    "    # Materialise the archives as plain dicts keyed by array name\n",
    "    A = dict(np.load(fv_path1+\".npz\"))\n",
    "    B = dict(np.load(fv_path2+\".npz\"))\n",
    "    l1 = len(A)\n",
    "    l2 = len(B)\n",
    "    if l1!=l2:\n",
    "        # Skip only this model; `break` would silently drop every model after it.\n",
    "        print(\"Skipping\", model_name, \"- layer counts differ:\", l1, \"vs\", l2)\n",
    "        continue\n",
    "    conv_layers = list(A.keys())\n",
    "\n",
    "    # The same row subset is used for both runs. Sizes are clamped because\n",
    "    # choice(..., replace=False) raises when asked for more rows than exist.\n",
    "    n_rows = A[conv_layers[0]].shape[0]\n",
    "    nsa_sample_indices = np.random.choice(n_rows, min(4000, n_rows), replace=False)\n",
    "    rtd_sample_indices = np.random.choice(n_rows, min(400, n_rows), replace=False)\n",
    "\n",
    "    rtd_heatmap = np.empty((l1,l2))\n",
    "    nsa_heatmap = np.empty((l1,l2))\n",
    "    for i,layer1 in enumerate(conv_layers):\n",
    "        for j,layer2 in enumerate(conv_layers):\n",
    "            print(\"Grid ID: \",i,j)\n",
    "            X = A[layer1]\n",
    "            Y = B[layer2]\n",
    "            print(\"Calculating NSA\")\n",
    "            nsa_heatmap[i,j]=nsa_criterion(torch.tensor(X[nsa_sample_indices],dtype=float), torch.tensor(Y[nsa_sample_indices],dtype=float))\n",
    "            print(nsa_heatmap[i,j])\n",
    "            print(\"Calculating RTD\")\n",
    "            rtd_heatmap[i,j]=rtd(torch.tensor(X[rtd_sample_indices]),torch.tensor(Y[rtd_sample_indices]))\n",
    "            print(rtd_heatmap[i,j])\n",
    "\n",
    "    # CKA is not computed in this test. The previous code still stored an\n",
    "    # uninitialised np.empty array under 'CKA', which was then plotted as if it\n",
    "    # were data; only the metrics that were actually computed are stored now.\n",
    "    heatmap_holder[model_name]['NSA']=nsa_heatmap\n",
    "    heatmap_holder[model_name]['RTD']=rtd_heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e31379a-72d1-42ab-bb55-308a14854734",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the RTD matrix of the last model processed by the loop above\n",
    "rtd_heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c533a16-8545-4575-bd5f-ed2f7359c8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Write the sanity-test results to disk so they can be reloaded by the next cell\n",
    "save_path = f'sanity_test_{dataset_name}_{task}_heatmap_holder.pkl'\n",
    "with open(save_path, mode='wb') as sink:\n",
    "    pickle.dump(heatmap_holder, sink)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e99f291-3387-41b0-90c3-886134009841",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the sanity-test results written by the previous cell.\n",
    "# pickle must only be used on files produced by this notebook.\n",
    "load_path = f'sanity_test_{dataset_name}_{task}_heatmap_holder.pkl'\n",
    "with open(load_path, mode='rb') as stored:\n",
    "    heatmap_holder = pickle.load(stored)\n",
    "\n",
    "heatmap_holder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa63a60d-975d-4c58-8a52-a61d8290e8c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Folder that receives one PNG per (model, metric) pair\n",
    "output_directory = f'heatmaps/{dataset_name}/sanity_tests/{task}/'\n",
    "# os.makedirs replaces the `!mkdir -p` shell escape: it works on every platform and\n",
    "# does not push an interpolated variable through a shell.\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "sns.set()  # seaborn theme is global state, so applying it once is enough\n",
    "annot_kwargs = {\"fontsize\": 25}  # size of the numbers written inside the cells\n",
    "\n",
    "# heatmap_holder is {model: {metric: 2-D array}}\n",
    "for model_name, metrics in heatmap_holder.items():\n",
    "    for metric_name, heatmap_data in metrics.items():\n",
    "        if metric_name=='CKA':\n",
    "            # CKA is plotted as 1 - CKA and labelled CKA'\n",
    "            # (presumably to read as a dissimilarity like the other metrics - confirm)\n",
    "            heatmap_data = 1 - heatmap_data\n",
    "            metric_name = \"CKA'\"\n",
    "\n",
    "        plt.figure(figsize=(10, 8))\n",
    "        sns.heatmap(heatmap_data, annot=True, annot_kws=annot_kwargs)\n",
    "\n",
    "        plt.title(f'{model_name} - {metric_name} - {task} Sanity Test Heatmap', fontsize=23)\n",
    "        plt.xlabel(f'{model_name} layers', fontsize=25)\n",
    "        plt.ylabel(f'{model_name} layers', fontsize=25)\n",
    "        plt.xticks(fontsize=21)\n",
    "        plt.yticks(fontsize=21)\n",
    "\n",
    "        output_filename = f'{model_name}_{metric_name}_heatmap.png'\n",
    "        output_path = os.path.join(output_directory, output_filename)\n",
    "        plt.savefig(output_path, bbox_inches='tight')\n",
    "        plt.close()  # free the figure; many are created in this loop\n",
    "\n",
    "        print(f'Saved: {output_path}')\n",
    "\n",
    "print('All heatmaps saved successfully.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0ddfb5b-310c-4c8e-b661-273795ad048e",
   "metadata": {},
   "source": [
    "## Cross Architecture Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e819cc06-b49c-4b91-80bf-e04c885cf20e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-architecture test: compare the last saved epoch of the first run for every\n",
    "# unordered pair of architectures, layer by layer.\n",
    "heatmap_holder = {}\n",
    "n = len(models)\n",
    "\n",
    "# Visit only the strict upper triangle (y < z) so each pair is processed once\n",
    "for y in range(n):\n",
    "    for z in range(y+1, n):\n",
    "        identifier1 = f'{models[y]}/{task}_{run_ids[0]}_'\n",
    "        identifier2 = f'{models[z]}/{task}_{run_ids[0]}_'\n",
    "\n",
    "        fv_path1 = path+identifier1+str(max_epochs[identifier1])\n",
    "        fv_path2 = path+identifier2+str(max_epochs[identifier2])\n",
    "        model_variant = models[y]+\"_\"+models[z]\n",
    "        print(fv_path1)\n",
    "        print(fv_path2)\n",
    "        print(\"Currently running for: \",model_variant)\n",
    "        heatmap_holder[model_variant]={}\n",
    "        A = dict(np.load(fv_path1+\".npz\"))\n",
    "        B = dict(np.load(fv_path2+\".npz\"))\n",
    "        l1 = len(A)\n",
    "        l2 = len(B)\n",
    "        if l1!=l2:\n",
    "            # Skip only this pair; `break` would also drop every remaining\n",
    "            # partner of models[y] without any message.\n",
    "            print(\"Skipping\", model_variant, \"- layer counts differ:\", l1, \"vs\", l2)\n",
    "            continue\n",
    "        conv_layers = list(A.keys())\n",
    "\n",
    "        # Sizes are clamped because choice(..., replace=False) raises when asked\n",
    "        # for more rows than exist.\n",
    "        n_rows = A[conv_layers[0]].shape[0]\n",
    "        nsa_sample_indices = np.random.choice(n_rows, min(4000, n_rows), replace=False)\n",
    "        rtd_sample_indices = np.random.choice(n_rows, min(400, n_rows), replace=False)\n",
    "\n",
    "        cka_heatmap = np.empty((l1,l2))\n",
    "        rtd_heatmap = np.empty((l1,l2))\n",
    "        nsa_heatmap = np.empty((l1,l2))\n",
    "        for i,layer1 in enumerate(conv_layers):\n",
    "            for j,layer2 in enumerate(conv_layers):\n",
    "                print(\"Grid ID: \",i,j)\n",
    "                X = A[layer1]\n",
    "                Y = B[layer2]\n",
    "                print(\"Calculating CKA\")\n",
    "                cka_heatmap[i,j]=cka(gram_linear(X),gram_linear(Y))\n",
    "                print(\"Calculating NSA\")\n",
    "                nsa_heatmap[i,j]=nsa_criterion(torch.tensor(X[nsa_sample_indices],dtype=float), torch.tensor(Y[nsa_sample_indices],dtype=float))\n",
    "                print(\"Calculating RTD\")\n",
    "                # Pass tensors, matching how the sanity-test cell calls rtd\n",
    "                rtd_heatmap[i,j]=rtd(torch.tensor(X[rtd_sample_indices]),torch.tensor(Y[rtd_sample_indices]))\n",
    "        heatmap_holder[model_variant]['CKA']=cka_heatmap\n",
    "        heatmap_holder[model_variant]['NSA']=nsa_heatmap\n",
    "        heatmap_holder[model_variant]['RTD']=rtd_heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d6fd973-d29c-46d8-ba64-3c7c2db698aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Write the cross-architecture results to disk so they can be reloaded by the next cell\n",
    "save_path = f'cross_arch_test_{dataset_name}_{task}_heatmap_holder.pkl'\n",
    "with open(save_path, mode='wb') as sink:\n",
    "    pickle.dump(heatmap_holder, sink)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7ba1113",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cross-architecture results written by the previous cell.\n",
    "# pickle must only be used on files produced by this notebook.\n",
    "load_path = f'cross_arch_test_{dataset_name}_{task}_heatmap_holder.pkl'\n",
    "with open(load_path, mode='rb') as stored:\n",
    "    heatmap_holder = pickle.load(stored)\n",
    "\n",
    "heatmap_holder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "404d2ae8-137c-4c79-9b4b-86d590e9ce03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Folder that receives one PNG per (model pair, metric)\n",
    "output_directory = f'heatmaps/{dataset_name}/cross_arch_tests/{task}/'\n",
    "# os.makedirs replaces the `!mkdir -p` shell escape: it works on every platform and\n",
    "# does not push an interpolated variable through a shell.\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "sns.set()  # seaborn theme is global state, so applying it once is enough\n",
    "annot_kwargs = {\"fontsize\": 25}  # size of the numbers written inside the cells\n",
    "\n",
    "# heatmap_holder is {'MODEL1_MODEL2': {metric: 2-D array}}\n",
    "for model_name, metrics in heatmap_holder.items():\n",
    "    # Keys were built as MODEL1 + '_' + MODEL2 by the compute cell\n",
    "    model1, model2 = model_name.split('_')\n",
    "    for metric_name, heatmap_data in metrics.items():\n",
    "        if metric_name=='CKA':\n",
    "            # CKA is plotted as 1 - CKA and labelled CKA'\n",
    "            # (presumably to read as a dissimilarity like the other metrics - confirm)\n",
    "            heatmap_data = 1 - heatmap_data\n",
    "            metric_name = \"CKA'\"\n",
    "\n",
    "        plt.figure(figsize=(10, 8))\n",
    "        sns.heatmap(heatmap_data, annot=True, annot_kws=annot_kwargs)\n",
    "        #plt.title(f'{model1} vs {model2} - {task} - {metric_name} Heatmap', fontsize=23)\n",
    "        plt.xlabel(f'{model2} layers', fontsize=25)  # columns index the second model's layers\n",
    "        plt.ylabel(f'{model1} layers', fontsize=25)  # rows index the first model's layers\n",
    "        plt.xticks(fontsize=21)\n",
    "        plt.yticks(fontsize=21)\n",
    "\n",
    "        output_filename = f'{model_name}_{metric_name}_heatmap.png'\n",
    "        output_path = os.path.join(output_directory, output_filename)\n",
    "        plt.savefig(output_path, bbox_inches='tight')\n",
    "        plt.close()  # free the figure; many are created in this loop\n",
    "\n",
    "        print(f'Saved: {output_path}')\n",
    "\n",
    "print('All heatmaps saved successfully.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "748b1b9c-7838-4867-aa49-b9d23d544857",
   "metadata": {},
   "source": [
    "## Cross Downstream Task Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "199b99fa-3c42-48d7-93ba-be6b27158943",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-task test: for each architecture, compare the last saved epoch of the first\n",
    "# run on tasks[0] against the one on tasks[1], layer by layer.\n",
    "heatmap_holder = {}\n",
    "tasks = [\"NC\",\"LP\"]\n",
    "n = len(models)\n",
    "\n",
    "for model in models:\n",
    "    identifier1 = f'{model}/{tasks[0]}_{run_ids[0]}_'\n",
    "    identifier2 = f'{model}/{tasks[1]}_{run_ids[0]}_'\n",
    "\n",
    "    fv_path1 = path+identifier1+str(max_epochs[identifier1])\n",
    "    fv_path2 = path+identifier2+str(max_epochs[identifier2])\n",
    "    print(fv_path1)\n",
    "    print(fv_path2)\n",
    "    print(\"Currently running for: \",model)\n",
    "    heatmap_holder[model]={}\n",
    "    A = dict(np.load(fv_path1+\".npz\"))\n",
    "    B = dict(np.load(fv_path2+\".npz\"))\n",
    "    l1 = len(A)\n",
    "    l2 = len(B)\n",
    "    if l1!=l2:\n",
    "        # Skip only this model; `break` would silently drop every model after it.\n",
    "        print(\"Skipping\", model, \"- layer counts differ:\", l1, \"vs\", l2)\n",
    "        continue\n",
    "    conv_layers = list(A.keys())\n",
    "\n",
    "    # Sizes are clamped because choice(..., replace=False) raises when asked for\n",
    "    # more rows than exist.\n",
    "    n_rows = A[conv_layers[0]].shape[0]\n",
    "    nsa_sample_indices = np.random.choice(n_rows, min(4000, n_rows), replace=False)\n",
    "    rtd_sample_indices = np.random.choice(n_rows, min(250, n_rows), replace=False)\n",
    "\n",
    "    cka_heatmap = np.empty((l1,l2))\n",
    "    rtd_heatmap = np.empty((l1,l2))\n",
    "    nsa_heatmap = np.empty((l1,l2))\n",
    "    for i,layer1 in enumerate(conv_layers):\n",
    "        for j,layer2 in enumerate(conv_layers):\n",
    "            print(\"Grid ID: \",i,j)\n",
    "            X = A[layer1]\n",
    "            Y = B[layer2]\n",
    "            print(\"Calculating CKA\")\n",
    "            cka_heatmap[i,j]=cka(gram_linear(X),gram_linear(Y))\n",
    "            print(\"Calculating NSA\")\n",
    "            nsa_heatmap[i,j]=nsa_criterion(torch.tensor(X[nsa_sample_indices],dtype=float), torch.tensor(Y[nsa_sample_indices],dtype=float))\n",
    "            print(\"Calculating RTD\")\n",
    "            # Pass tensors, matching how the sanity-test cell calls rtd\n",
    "            rtd_heatmap[i,j]=rtd(torch.tensor(X[rtd_sample_indices]),torch.tensor(Y[rtd_sample_indices]))\n",
    "    heatmap_holder[model]['CKA']=cka_heatmap\n",
    "    heatmap_holder[model]['NSA']=nsa_heatmap\n",
    "    heatmap_holder[model]['RTD']=rtd_heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b141601-0692-4c6c-a66d-609494a82691",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Write the cross-task results to disk so they can be reloaded by the next cell\n",
    "save_path = f'cross_task_test_{dataset_name}_heatmap_holder.pkl'\n",
    "with open(save_path, mode='wb') as sink:\n",
    "    pickle.dump(heatmap_holder, sink)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24a75c37-c8ba-455c-9198-dac947701fc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cross-task results written by the previous cell.\n",
    "# pickle must only be used on files produced by this notebook.\n",
    "load_path = f'cross_task_test_{dataset_name}_heatmap_holder.pkl'\n",
    "with open(load_path, mode='rb') as stored:\n",
    "    heatmap_holder = pickle.load(stored)\n",
    "\n",
    "heatmap_holder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6bb6bf5-3b01-4732-961c-5238c76f1265",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Axis captions: rows use the first entry, columns the second\n",
    "task_list = [\"Node Classification\",\"Link Prediction\"]\n",
    "\n",
    "# Folder that receives one PNG per (model, metric) pair\n",
    "output_directory = f'heatmaps/{dataset_name}/cross_task_tests/'\n",
    "# os.makedirs replaces the `!mkdir -p` shell escape: it works on every platform and\n",
    "# does not push an interpolated variable through a shell.\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "sns.set()  # seaborn theme is global state, so applying it once is enough\n",
    "annot_kwargs = {\"fontsize\": 25}  # size of the numbers written inside the cells\n",
    "\n",
    "# heatmap_holder is {model: {metric: 2-D array}}\n",
    "for model_name, metrics in heatmap_holder.items():\n",
    "    for metric_name, heatmap_data in metrics.items():\n",
    "        if metric_name=='CKA':\n",
    "            # CKA is plotted as 1 - CKA and labelled CKA'\n",
    "            # (presumably to read as a dissimilarity like the other metrics - confirm)\n",
    "            heatmap_data = 1 - heatmap_data\n",
    "            metric_name = \"CKA'\"\n",
    "\n",
    "        plt.figure(figsize=(10, 8))\n",
    "        sns.heatmap(heatmap_data, annot=True, annot_kws=annot_kwargs)\n",
    "        plt.title(f'{model_name} - {metric_name} Cross Task Test Heatmap', fontsize=23)\n",
    "        plt.xlabel(f'{task_list[1]} Layers', fontsize=25)\n",
    "        plt.ylabel(f'{task_list[0]} Layers', fontsize=25)\n",
    "        plt.xticks(fontsize=21)\n",
    "        plt.yticks(fontsize=21)\n",
    "\n",
    "        output_filename = f'{model_name}_{metric_name}_heatmap.png'\n",
    "        output_path = os.path.join(output_directory, output_filename)\n",
    "        plt.savefig(output_path, bbox_inches='tight')\n",
    "        plt.close()  # free the figure; many are created in this loop\n",
    "\n",
    "        print(f'Saved: {output_path}')\n",
    "\n",
    "print('All heatmaps saved successfully.')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d821533-bfcf-456e-8e27-06dc1ff49cf8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e20cd2d3-cef3-40dc-8fa9-d95607278082",
   "metadata": {},
   "source": [
    "## Convergence Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e80c629-4ceb-4ef7-98c8-bca80d9f9f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convergence test: for each architecture, compare every saved epoch of the first run\n",
    "# against that run's last saved epoch, one NSA value per (layer, epoch).\n",
    "heatmap_holder = {}\n",
    "n = len(models)\n",
    "\n",
    "for model in models:\n",
    "    identifier = f'{model}/{task}_{run_ids[0]}_'\n",
    "    final_epoch = max_epochs[identifier]\n",
    "    heatmap_holder[model]={}\n",
    "\n",
    "    # Reference features: the last saved epoch\n",
    "    fv_path1 = path+identifier+str(final_epoch)\n",
    "    print(fv_path1)\n",
    "    A = dict(np.load(fv_path1+\".npz\"))\n",
    "    print(\"Currently running for: \",model)\n",
    "    conv_layers = list(A.keys())\n",
    "\n",
    "    # Epochs are visited in steps of 5 (assumes a file exists for each - confirm)\n",
    "    epochs = range(5, final_epoch+1, 5)\n",
    "    # One row per layer actually present (the row count used to be hard-coded to 4)\n",
    "    # and one column per epoch. NaN-initialised so an epoch skipped below shows up\n",
    "    # as missing instead of uninitialised memory.\n",
    "    heatmap = np.full((len(conv_layers), len(epochs)), np.nan)\n",
    "    print(heatmap.shape)\n",
    "\n",
    "    # Size is clamped because choice(..., replace=False) raises when asked for more\n",
    "    # rows than exist.\n",
    "    n_rows = A[conv_layers[0]].shape[0]\n",
    "    nsa_sample_indices = np.random.choice(n_rows, min(4000, n_rows), replace=False)\n",
    "\n",
    "    for i,epoch in enumerate(epochs):\n",
    "        fv_path2 = path+identifier+str(epoch)\n",
    "        print(fv_path2, epoch)\n",
    "\n",
    "        B = dict(np.load(fv_path2+\".npz\"))\n",
    "        if len(A)!=len(B):\n",
    "            # Skip only this epoch; its column stays NaN\n",
    "            print(\"Skipping epoch\", epoch, \"- layer counts differ:\", len(A), \"vs\", len(B))\n",
    "            continue\n",
    "        for j,layer in enumerate(conv_layers):\n",
    "            print(\"Grid ID: \",j,i)\n",
    "            X = A[layer]\n",
    "            Y = B[layer]\n",
    "            print(\"Calculating NSA\")\n",
    "            heatmap[j,i]=nsa_criterion(torch.tensor(X[nsa_sample_indices],dtype=float), torch.tensor(Y[nsa_sample_indices],dtype=float))\n",
    "\n",
    "    # Stored once per model, after all epochs have been processed\n",
    "    heatmap_holder[model]['NSA']=heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4636d4bb-a9ad-4bc9-aae6-42d14a815b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Folder that receives one PNG per (model, metric) pair\n",
    "output_directory = f'heatmaps/{dataset_name}/convergence_tests/{task}/'\n",
    "# os.makedirs replaces the `!mkdir -p` shell escape: it works on every platform and\n",
    "# does not push an interpolated variable through a shell.\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "sns.set()  # seaborn theme is global state, so applying it once is enough\n",
    "\n",
    "# For 'NC' the figure is 20x8 with annotated cells; for any other task it is 40x16\n",
    "# without annotations.\n",
    "is_nc = task=='NC'\n",
    "figure_size = (20, 8) if is_nc else (40, 16)\n",
    "\n",
    "# heatmap_holder is {model: {metric: layers x epochs array}}\n",
    "for model_name, metrics in heatmap_holder.items():\n",
    "    # Epoch number of every heatmap column, built the same way as in the compute cell\n",
    "    identifier = f'{model_name}/{task}_{run_ids[0]}_'\n",
    "    epoch_labels = np.arange(5,max_epochs[identifier]+1,5)\n",
    "\n",
    "    for metric_name, heatmap_data in metrics.items():\n",
    "        plt.figure(figsize=figure_size)\n",
    "        if metric_name=='CKA':\n",
    "            # CKA is plotted as 1 - CKA and labelled CKA'\n",
    "            heatmap_data = 1 - heatmap_data\n",
    "            metric_name = \"CKA'\"\n",
    "        sns.heatmap(heatmap_data, annot=is_nc)\n",
    "        plt.title(f'{model_name} - {metric_name} - {task} Convergence Test Heatmap', fontsize=23)\n",
    "        plt.xlabel('Epochs', fontsize=23)\n",
    "        plt.ylabel(f'Layers of {model_name}', fontsize=23)\n",
    "\n",
    "        # seaborn may drop some x ticks on wide heatmaps so labels do not overlap.\n",
    "        # Passing the full label list then raises because the counts differ, so\n",
    "        # each remaining tick (placed at column index + 0.5) is labelled with the\n",
    "        # epoch of the column it sits on.\n",
    "        xticks = plt.xticks()[0]\n",
    "        plt.xticks(xticks, epoch_labels[np.floor(xticks).astype(int)])\n",
    "\n",
    "        output_filename = f'{model_name}_{metric_name}_heatmap.png'\n",
    "        output_path = os.path.join(output_directory, output_filename)\n",
    "        plt.savefig(output_path, bbox_inches='tight')\n",
    "        plt.close()  # free the figure; many are created in this loop\n",
    "\n",
    "        print(f'Saved: {output_path}')\n",
    "\n",
    "print('All heatmaps saved successfully.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d00464ec-fe55-470d-9e31-980981c1cbac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
