{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repository root: three directory levels above the notebook's working directory.\n",
    "# os.path helpers keep this separator-agnostic and make a shallow working\n",
    "# directory resolve to the filesystem root rather than to an empty string.\n",
    "absolute_path = os.path.abspath(\n",
    "    os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir))\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put the repository root on the import path so the project-local `utils`\n",
    "# package resolves. The membership check keeps the cell idempotent: re-running\n",
    "# it does not append the same entry again.\n",
    "if absolute_path not in sys.path:\n",
    "    sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory under the repository root; the figure filename is built from it.\n",
    "path_to_save = os.path.join(\n",
    "    absolute_path,\n",
    "    \"experiments/dump_results\",\n",
    ")\n",
    "# Relative location of the evaluation logs.\n",
    "eval_path = \"logs/eval_models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PNKA similarity matrices live under this directory (relative to the repository root).\n",
    "sim_mx_logs = \"logs/similarity_matrix/pnka/diff_archs\"\n",
    "\n",
    "# The two (architecture, layer) pairs being compared.\n",
    "arch1, layer = \"vgg16\", \"l19\"\n",
    "arch2, layer2 = \"incep\", \"l17\"\n",
    "\n",
    "# Dataset name as it appears in the log folder names.\n",
    "dataset = \"cifar100\"\n",
    "\n",
    "# Seeds 0, 1 and 2.\n",
    "seeds = np.arange(3)\n",
    "seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every seed (both models use the same seed) load the saved PNKA tensor and\n",
    "# the CKA tensors, keeping torch.diag of the PNKA tensor and the scalar\n",
    "# trace(CKA matrix) / norm.\n",
    "# map_location='cpu' lets files saved from a GPU load on a CPU-only machine and\n",
    "# keeps the tensors convertible to pandas/NumPy in later cells.\n",
    "cka_logs = 'logs/similarity_matrix/cka/diff_archs'\n",
    "\n",
    "sims = []\n",
    "cka_sims = []\n",
    "folder_names = []\n",
    "for seed in seeds:\n",
    "    sim_mx_folder = f'M1_{dataset}-{arch1}-seed{seed}_{layer}_M2_{dataset}-{arch2}-seed{seed}_{layer2}'\n",
    "    folder_names.append(f'{seed}_{seed}')\n",
    "\n",
    "    # Both result trees share the same <folder>/<dataset>_testset layout.\n",
    "    pnka_dir = os.path.join(absolute_path, sim_mx_logs, sim_mx_folder, f'{dataset}_testset')\n",
    "    cka_dir = os.path.join(absolute_path, cka_logs, sim_mx_folder, f'{dataset}_testset')\n",
    "\n",
    "    sim_mx = torch.load(os.path.join(pnka_dir, 'pnka.pt'), map_location='cpu')\n",
    "    cka_mx = torch.load(os.path.join(cka_dir, 'final_sim_xxtyyt.pt'), map_location='cpu')\n",
    "    norm = torch.load(os.path.join(cka_dir, 'norm.pt'), map_location='cpu')\n",
    "\n",
    "    cka_sims.append(torch.trace(cka_mx) / norm)\n",
    "    # NOTE(review): presumably sim_mx is a square matrix, so this is its diagonal - confirm.\n",
    "    sims.append(torch.diag(sim_mx))\n",
    "\n",
    "sims = torch.stack(sims)\n",
    "cka_sims = torch.stack(cka_sims)\n",
    "sims.shape, cka_sims.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full path of the histogram PDF inside the output directory.\n",
    "filename = os.path.join(\n",
    "    path_to_save,\n",
    "    f'hist_{dataset}-{arch1}{arch2}-avgoverdiffseeds.pdf',\n",
    ")\n",
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per seed pair, keyed by its '<seed>_<seed>' label.\n",
    "data = dict(zip(folder_names, sims))\n",
    "df = pd.DataFrame(data)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise mean and standard deviation across the seed-pair columns only.\n",
    "# Selecting `folder_names` explicitly keeps this cell idempotent: on a re-run\n",
    "# the previously added 'avg' and 'std' columns are not folded into the statistics.\n",
    "seed_columns = df[folder_names]\n",
    "avg = seed_columns.mean(axis=1)\n",
    "std = seed_columns.std(axis=1)\n",
    "df['avg'] = avg\n",
    "df['std'] = std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the frame, which at this point also carries the 'avg' and 'std' columns.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of the per-seed CKA scores.\n",
    "cka_score_std = cka_sims.std()\n",
    "cka_score = cka_sims.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bin edges: start at 0, step 0.001, stop before 1.1.\n",
    "ranges = np.arange(start=0, stop=1.1, step=0.001)\n",
    "ranges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One sub-frame per pair of consecutive bin edges, holding the rows whose 'avg'\n",
    "# lies in the half-open interval [lower, upper).\n",
    "split_df = [\n",
    "    df[(df['avg'] >= lower) & (df['avg'] < upper)]\n",
    "    for lower, upper in zip(ranges[:-1], ranges[1:])\n",
    "]\n",
    "len(split_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per bin: row count, mean of 'avg' and mean of 'std' (0 is recorded for empty bins).\n",
    "count_points_split_df = [len(chunk) for chunk in split_df]\n",
    "avg_sim_split_df = [\n",
    "    chunk['avg'].mean() if len(chunk) != 0 else 0\n",
    "    for chunk in split_df\n",
    "]\n",
    "std_sim_split_df = [\n",
    "    chunk['std'].mean() if len(chunk) != 0 else 0\n",
    "    for chunk in split_df\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Overall Histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean PNKA of the last seed pair. The column label is derived from `seeds`\n",
    "# directly instead of relying on a loop variable left over from an earlier cell.\n",
    "last_seed = seeds[-1]\n",
    "df[f'{last_seed}_{last_seed}'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of each seed pair's column, then averaged over the pairs.\n",
    "per_seed_mean = [df[f'{seed}_{seed}'].mean() for seed in seeds]\n",
    "per_seed_std = [df[f'{seed}_{seed}'].std() for seed in seeds]\n",
    "\n",
    "agg_pnka = torch.tensor(per_seed_mean).mean()\n",
    "std_agg_pnka = torch.tensor(per_seed_std).mean()\n",
    "\n",
    "agg_pnka, std_agg_pnka"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the output path that the plotting cell passes as `save_path`.\n",
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Histogram of the per-row average PNKA (the 'avg' column), with `filename` as save path.\n",
    "# The mean CKA score is passed as `x_vline`, annotated as \"CKA=<mean>(±<std>)\" rounded to 2 decimals.\n",
    "# NOTE(review): `plot_histogram` is a project helper not shown here - presumably `x_vline`\n",
    "# draws a vertical line and `save_path` writes the figure to disk; confirm against utils.\n",
    "fig = visualization.plot_histogram(\n",
    "    all_x=[df['avg']],\n",
    "    all_names=['count'],\n",
    "    yaxis_title_text='Count',\n",
    "    xaxis_title_text='Average PNKA',\n",
    "#     histnorm='probability density',\n",
    "#     error_y=dict(type='data', array=df['std']),\n",
    "    log=False,\n",
    "    xaxis=dict(\n",
    "            linecolor = \"black\"),\n",
    "    yaxis=dict(\n",
    "            linecolor = \"black\",),\n",
    "    xrange=[0,1],\n",
    "#     yrange=[0,400],\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=23,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    showlegend=False,\n",
    "    save_path=filename,\n",
    "    x_vline=cka_score,\n",
    "    linecolor='orange',\n",
    "    vline_annotation=f'CKA={round(cka_score.item(),2)}(±{round(cka_score_std.item(),2)})',\n",
    ")\n",
    "# Optional second marker for the aggregated PNKA score (currently disabled):\n",
    "# fig.add_vline(x=agg_pnka, line_width=2, line_dash='dash', line_color='green',\n",
    "#               annotation_text=r\"$\\bar{PNKA}$\" + f\"={round(agg_pnka.item(),2)}(±{round(std_agg_pnka.item(),2)}\",\n",
    "#               annotation_position='top')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest, smallest and mean per-row standard deviation.\n",
    "# pandas reductions are vectorised and skip NaN, whereas the builtin max()/min()\n",
    "# iterate element by element and give order-dependent results when NaN is present.\n",
    "std_column = df['std']\n",
    "std_column.max(), std_column.min(), std_column.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "foo",
   "language": "python",
   "name": "foo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
