{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "visible-church",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "distributed-secret",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory expected to contain `logs/similarity_matrix/` (joined with it in the loading cell):\n",
    "# three levels above the notebook's working directory. os.path.dirname is separator-aware, so\n",
    "# unlike splitting on '/' this also works on Windows and never collapses to an empty string.\n",
    "absolute_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inside-bridges",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which saved similarity matrices to read: dataset, layer and architecture pair.\n",
    "dataset = 'cifar100'\n",
    "layer = 'l17'\n",
    "arch1 = 'incep'\n",
    "arch2 = arch1  # both models share one architecture and differ only in seed\n",
    "# PNKA results directory, relative to each root directory collected in `path`.\n",
    "sim_mx_logs = 'logs/similarity_matrix/pnka/'\n",
    "\n",
    "# Seeds are compared positionally in pairs: (0, 1), (2, 3), (4, 5).\n",
    "seeds_even, seeds_odd = np.arange(0, 6, 2), np.arange(1, 6, 2)\n",
    "seeds_even, seeds_odd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "anticipated-missile",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One root directory per seed pair, aligned with `seeds_even` (and therefore with `seeds_odd`).\n",
    "# Pairs starting at seed 0 are read from `absolute_path`; all other pairs from `absolute_logs_path`.\n",
    "# FIXME: `absolute_logs_path` is not defined anywhere in this notebook - it leaked in from kernel\n",
    "# state, so Restart & Run All used to die here with a bare NameError. Define it in the config cell.\n",
    "supported_archs = ('r18', 'vgg16', 'incep')\n",
    "if arch1 not in supported_archs:\n",
    "    # An unknown arch would otherwise leave `path` short and fail later with an IndexError.\n",
    "    raise ValueError(f'unsupported arch1 {arch1!r}; expected one of {supported_archs}')\n",
    "logs_root = globals().get('absolute_logs_path')\n",
    "\n",
    "path = []\n",
    "for seed1 in seeds_even:\n",
    "    if seed1 == 0:\n",
    "        path.append(absolute_path)\n",
    "    elif logs_root is None:\n",
    "        raise NameError('absolute_logs_path is not defined: set it to the root directory '\n",
    "                        'that holds the similarity-matrix logs for seeds other than 0')\n",
    "    else:\n",
    "        path.append(logs_root)\n",
    "path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inclusive-candy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (even seed, odd seed) pair, load the saved PNKA similarity matrix and the CKA\n",
    "# ingredients. map_location='cpu' lets the notebook run without a GPU and keeps the tensors\n",
    "# convertible to NumPy/pandas further down, even if they were saved from a CUDA device.\n",
    "cka_logs = 'logs/similarity_matrix/cka/'\n",
    "sims = []\n",
    "cka_sims = []\n",
    "folder_names = []\n",
    "\n",
    "for idx, (seed1, seed2) in enumerate(zip(seeds_even, seeds_odd)):\n",
    "    print(seed1, seed2)\n",
    "    model_name1 = f'{dataset}-{arch1}-seed{seed1}'\n",
    "    model_name2 = f'{dataset}-{arch2}-seed{seed2}'\n",
    "    sim_mx_folder = f'M1_{model_name1}_{layer}_M2_{model_name2}_{layer}'\n",
    "    folder_names.append(f'{seed1}_{seed2}')\n",
    "    # Index path[idx] (rather than zipping) so a too-short `path` fails loudly instead of dropping pairs.\n",
    "    pnka_dir = os.path.join(path[idx], sim_mx_logs, sim_mx_folder)\n",
    "    cka_dir = os.path.join(path[idx], cka_logs, sim_mx_folder)\n",
    "    sim_mx = torch.load(os.path.join(pnka_dir, 'pnka.pt'), map_location='cpu')\n",
    "    cka_mx = torch.load(os.path.join(cka_dir, 'final_sim_xxtyyt.pt'), map_location='cpu')\n",
    "    norm = torch.load(os.path.join(cka_dir, 'norm.pt'), map_location='cpu')\n",
    "    # Per-pair CKA value: trace of the stored matrix divided by the stored normaliser.\n",
    "    cka_sims.append(torch.trace(cka_mx) / norm)\n",
    "    # Keep only the diagonal of the PNKA matrix (presumably one score per input point).\n",
    "    sims.append(torch.diag(sim_mx))\n",
    "\n",
    "# sims: (num_pairs, diagonal length); cka_sims stacks one trace/norm value per pair\n",
    "# (norm presumably a scalar - TODO confirm).\n",
    "sims = torch.stack(sims)\n",
    "cka_sims = torch.stack(cka_sims)\n",
    "sims.shape, cka_sims.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wanted-drill",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per seed pair ('<even>_<odd>'), one row per diagonal entry of the PNKA matrix.\n",
    "# .detach().cpu().numpy() hands pandas a plain ndarray: tensors on a CUDA device or tensors\n",
    "# that require grad cannot be converted to NumPy implicitly.\n",
    "data = {name: row.detach().cpu().numpy() for name, row in zip(folder_names, sims)}\n",
    "df = pd.DataFrame(data)\n",
    "assert list(df.columns) == folder_names, 'seed-pair columns were lost or collapsed'\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "overall-conservation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-row mean and sample std (pandas default ddof=1) across the seed-pair columns.\n",
    "# Summary columns from a previous run of this cell are dropped first; otherwise re-running it\n",
    "# folds 'avg' and 'std' into their own statistics and silently changes the results.\n",
    "pair_cols = df.drop(columns=['avg', 'std'], errors='ignore')\n",
    "df['avg'] = pair_cols.mean(axis=1)\n",
    "df['std'] = pair_cols.std(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abstract-blood",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Averages over all rows of the per-row mean and of the per-row spread across seed pairs.\n",
    "tuple(df[column].mean() for column in ('avg', 'std'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "weekly-documentary",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean CKA over seed pairs and its spread (torch.std default: unbiased, i.e. sample std).\n",
    "cka_score, cka_score_std = cka_sims.mean(), cka_sims.std()\n",
    "cka_score.item(), cka_score_std.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lonely-lewis",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary row 'CKA (±std) & PNKA (±std)', every number rounded to 3 decimals;\n",
    "# the '&' separator suggests it is meant to be pasted into a LaTeX table.\n",
    "ours_avg, ours_std = (round(float(df[column].mean()), 3) for column in ('avg', 'std'))\n",
    "'{} (±{}) & {} (±{})'.format(\n",
    "    round(cka_score.item(), 3), round(cka_score_std.item(), 3), ours_avg, ours_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "loved-inclusion",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "powerful-insurance",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "(newadv2)",
   "language": "python",
   "name": "newadv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
