{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c5b4c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "from scipy import stats\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e09d1a03",
   "metadata": {},
   "outputs": [],
   "source": [
    "absolute_path = \"/\".join(os.path.abspath(os.getcwd()).split('/')[:-3])\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45a069b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis\n",
    "from utils.data import CLASS_ID_TO_NAMES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e6bc57",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_path = 'logs/eval_models'\n",
    "path_to_save = os.path.join(absolute_path, 'experiments/dump_results')\n",
    "ks = [500, 1000, 2000]\n",
    "# ks = [5, 10, 100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4be65b",
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx_logs = 'logs/similarity_matrix/efficient_pnka/topk_landmarks_from_trainset'\n",
    "arch1 = 'r18'\n",
    "arch2 = arch1\n",
    "layer = 'l17'\n",
    "dataset = 'cifar10'\n",
    "\n",
    "seeds_even = np.array([0,0,1])\n",
    "seeds_odd = np.array([1,2,2])\n",
    "seeds_even, seeds_odd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "viral-pakistan",
   "metadata": {},
   "outputs": [],
   "source": [
    "path_m1 = []\n",
    "for seed1 in seeds_even:\n",
    "    if arch1 == 'r18':\n",
    "        if seed1 == 0 or seed1 == 1:\n",
    "            path_m1.append(absolute_path_r4)\n",
    "        else:\n",
    "            path_m1.append(absolute_path_r5)\n",
    "    if arch1 == 'vgg16' or arch1 == 'incep':\n",
    "        if seed1 == 0 or seed1 == 1:\n",
    "            path_m1.append(absolute_path_r4)\n",
    "        else:\n",
    "            path_m1.append(absolute_path_r5)\n",
    "            \n",
    "path_m2 = []\n",
    "for seed1 in seeds_odd:\n",
    "    if arch1 == 'r18':\n",
    "        if seed1 == 0 or seed1 == 1:\n",
    "            path_m2.append(absolute_path_r4)\n",
    "        else:\n",
    "            path_m2.append(absolute_path_r5)\n",
    "    if arch1 == 'vgg16' or arch1 == 'incep':\n",
    "        if seed1 == 0 or seed1 == 1:\n",
    "            path_m2.append(absolute_path_r4)\n",
    "        else:\n",
    "            path_m2.append(absolute_path_r5)\n",
    "path_m1, path_m2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "obvious-turkey",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data = f'{dataset}_testset'\n",
    "input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "academic-tribute",
   "metadata": {},
   "outputs": [],
   "source": [
    "sims = []\n",
    "cka_sims = []\n",
    "folder_names = []\n",
    "\n",
    "all_ratio_intersection, all_topk_neighbors_x, all_topk_neighbors_y = [], [], []\n",
    "for idx, (seed1,seed2) in enumerate(zip(seeds_even, seeds_odd)):\n",
    "    print(seed1, seed2)\n",
    "    model_name1 = f'{dataset}-{arch1}-seed{seed1}'\n",
    "    model_name2 = f'{dataset}-{arch2}-seed{seed2}'\n",
    "    sim_mx_folder = f'M1_{model_name1}_{layer}_M2_{model_name2}_{layer}'\n",
    "    folder_names.append(f'{seed1}_{seed2}')\n",
    "\n",
    "    sim_mx = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_path, sim_mx_logs, sim_mx_folder, input_data, 'pnka.pt'))\n",
    "    cka_mx = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_path, 'logs/similarity_matrix/cka/', sim_mx_folder, input_data, 'final_sim_xxtyyt.pt'))\n",
    "    norm = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_path, 'logs/similarity_matrix/cka/', sim_mx_folder, input_data, 'norm.pt'))\n",
    "    cka_sims.append(torch.trace(cka_mx)/norm)\n",
    "    sim = sim_mx\n",
    "#     sim = torch.diag(sim_mx)\n",
    "    sims.append(sim)\n",
    "    \n",
    "    features_x = torch.load(os.path.join(path_m1[idx], 'logs/representations/', model_name1, f'{layer}.pt'))\n",
    "    features_y = torch.load(os.path.join(path_m2[idx], 'logs/representations/', model_name2, f'{layer}.pt'))\n",
    "\n",
    "    cos_features_x = df_analysis.get_cosine_similarity(features_x, features_x)\n",
    "    cos_features_y = df_analysis.get_cosine_similarity(features_y , features_y)\n",
    "\n",
    "    aux_all_ratio_intersection, aux_all_topk_neighbors_x, aux_all_topk_neighbors_y = [], [], []\n",
    "    for k in ks:\n",
    "        ratio_intersection, topk_neighbors_x, topk_neighbors_y = df_analysis.get_ratio_intersection_neighbors(\n",
    "            cos_features_x, cos_features_y, k, largest=True)\n",
    "        aux_all_ratio_intersection.append(ratio_intersection)\n",
    "        aux_all_topk_neighbors_x.append(topk_neighbors_x)\n",
    "        aux_all_topk_neighbors_y.append(topk_neighbors_y)\n",
    "    all_ratio_intersection.append(aux_all_ratio_intersection)\n",
    "    all_topk_neighbors_x.append(aux_all_topk_neighbors_x)\n",
    "    all_topk_neighbors_y.append(aux_all_topk_neighbors_y)\n",
    "        \n",
    "sims = torch.stack(sims)\n",
    "cka_sims = torch.stack(cka_sims)\n",
    "sims.shape, cka_sims.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270c5fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}\n",
    "for folder_name, sim in zip(folder_names, sims):\n",
    "    data[folder_name] = sim\n",
    "df = pd.DataFrame(data)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "substantial-louisville",
   "metadata": {},
   "outputs": [],
   "source": [
    "avg = df.mean(axis=1)\n",
    "std = df.std(axis=1)\n",
    "df['avg'] = avg\n",
    "df['std'] = std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "systematic-malawi",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_ratio_intersection = torch.tensor(all_ratio_intersection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "african-partnership",
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_all_ratio_intersection = torch.mean(all_ratio_intersection, 0)\n",
    "std_all_ratio_intersection = torch.std(all_ratio_intersection, 0)\n",
    "avg_all_ratio_intersection.shape, std_all_ratio_intersection.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "selected-british",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, (k, avg_ratio, std_ratio) in enumerate(zip(ks, avg_all_ratio_intersection, std_all_ratio_intersection)):\n",
    "    df[f'avg_ratio_{k}'] = avg_ratio\n",
    "    df[f'std_ratio_{k}'] = std_ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "engaging-sunday",
   "metadata": {},
   "outputs": [],
   "source": [
    "cka_score = torch.mean(cka_sims)\n",
    "cka_score_std = torch.std(cka_sims)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b87ff0",
   "metadata": {},
   "source": [
    "=> order diagonal based on values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a36d11f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_df_sim = df_analysis.sort_dataframe(\n",
    "    df, 'avg', ascending=True)\n",
    "sorted_df_sim['relative_index'] = np.arange(sorted_df_sim.shape[0])\n",
    "sorted_df_sim.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c43b7ac",
   "metadata": {},
   "source": [
    "### Box plot of similarity vs neighbors overlap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "short-engine",
   "metadata": {},
   "outputs": [],
   "source": [
    "ranges = np.arange(0, 1.1, 0.1)\n",
    "ranges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "supreme-measurement",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "split_df = []\n",
    "for x, y in zip(ranges, ranges[1:]):\n",
    "    split_df.append(sorted_df_sim[(sorted_df_sim['avg'] > x) & (sorted_df_sim['avg'] < y)])\n",
    "split_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "resident-tablet",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for k in ks:\n",
    "    fig = visualization.plot_box_plot_with_bins(\n",
    "        split_df, column_name_x='avg', yrange=[0,1], #xrange=[0,1],\n",
    "        column_name_y=f'avg_ratio_{k}',\n",
    "        showlegend=False,\n",
    "        yaxis_title_text='Fraction of Overlap of Neighbors',\n",
    "        xaxis_title_text='Average PNKA',\n",
    "        font=dict(\n",
    "            family=\"Times New Roman\",\n",
    "            size=23,\n",
    "            color=\"Black\"\n",
    "        ),\n",
    "        xtick_angle=30,\n",
    "        save_path=os.path.join(path_to_save, f'box_plot_k{k}_{dataset}-{arch1}-{layer}-avgoverdiffseeds_top1000balancedlandmarksfromtrain.pdf')\n",
    "    )\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "international-appointment",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "indian-calculator",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c09cb36",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "broken-colony",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "foo",
   "language": "python",
   "name": "foo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
