{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "absolute_path = \"/\".join(os.path.abspath(os.getcwd()).split('/')[:-3])\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_path = 'logs/eval_models'\n",
    "path_to_save = os.path.join(absolute_path, 'experiments/dump_results')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx_logs = 'logs/similarity_matrix/pnka/'\n",
    "arch1 = 'r18'\n",
    "arch2 = arch1\n",
    "layer = 'l17'\n",
    "dataset = 'cifar100'\n",
    "\n",
    "seeds_even = np.arange(0, 6, 2)\n",
    "seeds_odd = np.arange(1, 6, 2)\n",
    "seeds_even, seeds_odd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = []\n",
    "for seed1 in seeds_even:\n",
    "#     if arch1 == 'r18':\n",
    "#         if dataset == 'cifar10':\n",
    "#             if seed1 == 0:\n",
    "#                 path.append(absolute_path)\n",
    "#             else:\n",
    "#                 path.append(absolute_logs_path)\n",
    "#     if arch1 == 'vgg16' or arch1 == 'incep':\n",
    "    if seed1 == 0:\n",
    "        path.append(absolute_path)\n",
    "    else:\n",
    "        path.append(absolute_logs_path)\n",
    "path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "transform_test = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    root=f'{absolute_path}/data', train=False, download=False, transform=transform_test)\n",
    "\n",
    "labels = torch.tensor(testset.targets)\n",
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sims = []\n",
    "cka_sims = []\n",
    "folder_names = []\n",
    "all_preds = []\n",
    "correct_m1, correct_m2 = [], []\n",
    "for idx, (seed1,seed2) in enumerate(zip(seeds_even, seeds_odd)):\n",
    "    print(seed1, seed2)\n",
    "    model_name1 = f'{dataset}-{arch1}-seed{seed1}'\n",
    "    model_name2 = f'{dataset}-{arch2}-seed{seed2}'\n",
    "    \n",
    "    sim_mx_folder = f'M1_{model_name1}_{layer}_M2_{model_name2}_{layer}'\n",
    "    folder_names.append(f'{seed1}_{seed2}')\n",
    "    sim_mx = torch.load(\n",
    "        os.path.join(\n",
    "            path[idx], sim_mx_logs, sim_mx_folder, 'pnka.pt'))\n",
    "    cka_mx = torch.load(\n",
    "        os.path.join(\n",
    "            path[idx], 'logs/similarity_matrix/cka/', sim_mx_folder, 'final_sim_xxtyyt.pt'))\n",
    "    norm = torch.load(\n",
    "        os.path.join(\n",
    "            path[idx], 'logs/similarity_matrix/cka/', sim_mx_folder, 'norm.pt'))\n",
    "    cka_sims.append(torch.trace(cka_mx)/norm)\n",
    "    sim = torch.diag(sim_mx)\n",
    "    sims.append(sim)\n",
    "    \n",
    "    preds1 = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_path, eval_path, model_name1, 'preds.pt'))\n",
    "    preds2 = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_path, eval_path, model_name2, 'preds.pt'))\n",
    "    all_preds.append(preds1.int())\n",
    "    all_preds.append(preds2.int())\n",
    "#     correct_m1.append(torch.eq(preds1, labels).int())\n",
    "#     correct_m2.append(torch.eq(preds2, labels).int())\n",
    "\n",
    "sims = torch.stack(sims)\n",
    "all_preds = torch.stack(all_preds)\n",
    "cka_sims = torch.stack(cka_sims)\n",
    "sims.shape, cka_sims.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correctness = []\n",
    "correct_m1m2, correct_m3m4, correct_m5m6 = [], [], []\n",
    "for idx, (preds, label) in enumerate(zip(all_preds.T, labels)):\n",
    "#     print(preds, label, torch.all(preds == label))\n",
    "    correctness.append(torch.all(preds == label).int())\n",
    "    correct_m1m2.append(torch.eq(preds[0], preds[1]).int())\n",
    "    correct_m3m4.append(torch.eq(preds[2], preds[3]).int())\n",
    "    correct_m5m6.append(torch.eq(preds[4], preds[5]).int())\n",
    "correctness = torch.tensor(correctness)\n",
    "correct_m1m2 = torch.tensor(correct_m1m2)\n",
    "correct_m3m4 = torch.tensor(correct_m3m4)\n",
    "correct_m5m6 = torch.tensor(correct_m5m6)\n",
    "correctness.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}\n",
    "for folder_name, sim in zip(folder_names, sims):\n",
    "    data[folder_name] = sim\n",
    "df = pd.DataFrame(data)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg = df.mean(axis=1)\n",
    "std = df.std(axis=1)\n",
    "df['avg'] = avg\n",
    "df['std'] = std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max(df['std']), min(df['std']), df['std'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['all_correct'] = correctness\n",
    "df['correct_m1m2'] = correct_m1m2\n",
    "df['correct_m3m4'] = correct_m3m4\n",
    "df['correct_m5m6'] = correct_m5m6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['ratio_correct_preds'] = df[['correct_m1m2','correct_m3m4','correct_m5m6']].mean(axis=1)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cka_score = torch.mean(cka_sims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_agree = df[df['all_correct'] == 1]['avg']\n",
    "sim_disagree = df[df['all_correct'] != 1]['avg']\n",
    "sim_agree.shape, sim_disagree.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Overall Histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = os.path.join(path_to_save, f'hist_bycorrectness_{dataset}-{arch1}-{layer}_avgoverdiffseeds_nbinsx10.pdf')\n",
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# visualization.plot_histogram(\n",
    "# #     all_x=[df['avg']],\n",
    "#     all_x=[sim_agree, sim_disagree],\n",
    "#     all_names=['Correct Prediction', 'Incorrect Prediction'],\n",
    "#     histnorm='percent',\n",
    "# #     histfunc='sum',\n",
    "#     yaxis_title_text='Percentage of Points',\n",
    "#     xaxis_title_text='Average PNKA',\n",
    "#     nbinsx=10,\n",
    "#     yrange=[0,100],\n",
    "#     xrange=[0.,1],\n",
    "# #     xaxis=dict(\n",
    "# #         dtick=0.1,\n",
    "# #         linecolor = \"black\"),\n",
    "#     xaxis = dict(\n",
    "#         dtick=0.1,\n",
    "#         linecolor = \"black\",\n",
    "#         spikedash = 'solid',\n",
    "#         showgrid=True,\n",
    "#         gridcolor='lightgrey',\n",
    "#     ),\n",
    "# #     yaxis=dict(\n",
    "# #             linecolor = \"black\",),\n",
    "#     yaxis = dict(\n",
    "#         linecolor = \"black\",\n",
    "#         spikedash = 'solid',\n",
    "#         showgrid=True,\n",
    "#         gridcolor='lightgrey',\n",
    "#     ),\n",
    "#     font=dict(\n",
    "#         family=\"Times New Roman\",\n",
    "#         size=23,\n",
    "#         color=\"Black\"\n",
    "#     ),\n",
    "#     legend=dict(\n",
    "#         yanchor=\"top\",\n",
    "#         y=0.99,\n",
    "#         xanchor=\"left\",\n",
    "#         x=0.01,\n",
    "#         title_font_family=\"Times New Roman\",\n",
    "#         font=dict(size= 25)\n",
    "#     ),\n",
    "#     save_path=filename,\n",
    "# #     x_vline=cka_score,\n",
    "# #     linecolor='orange',\n",
    "# #     vline_annotation=f'CKA={round(cka_score.item(),2)}',\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_agree = df[df['ratio_correct_preds'] == 1]['avg']\n",
    "sim_disagree = df[df['ratio_correct_preds'] != 1]['avg']\n",
    "sim_agree.shape, sim_disagree.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualization.plot_histogram(\n",
    "#     all_x=[df['avg']],\n",
    "    all_x=[sim_agree, sim_disagree],\n",
    "    all_names=['Correct Prediction', 'Incorrect Prediction'],\n",
    "    histnorm='percent',\n",
    "#     histfunc='sum',\n",
    "    yaxis_title_text='Percentage of Points',\n",
    "    xaxis_title_text='Average PNKA',\n",
    "    nbinsx=10,\n",
    "    yrange=[0,100],\n",
    "    xrange=[0.,1],\n",
    "#     xaxis=dict(\n",
    "#         dtick=0.1,\n",
    "#         linecolor = \"black\"),\n",
    "    xaxis = dict(\n",
    "        dtick=0.1,\n",
    "        linecolor = \"black\",\n",
    "        spikedash = 'solid',\n",
    "        showgrid=True,\n",
    "        gridcolor='lightgrey',\n",
    "    ),\n",
    "#     yaxis=dict(\n",
    "#             linecolor = \"black\",),\n",
    "    yaxis = dict(\n",
    "        linecolor = \"black\",\n",
    "        spikedash = 'solid',\n",
    "        showgrid=True,\n",
    "        gridcolor='lightgrey',\n",
    "    ),\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=23,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    legend=dict(\n",
    "        yanchor=\"top\",\n",
    "        y=0.99,\n",
    "        xanchor=\"left\",\n",
    "        x=0.01,\n",
    "        title_font_family=\"Times New Roman\",\n",
    "        font=dict(size= 25)\n",
    "    ),\n",
    "    save_path=filename,\n",
    "#     x_vline=cka_score,\n",
    "#     linecolor='orange',\n",
    "#     vline_annotation=f'CKA={round(cka_score.item(),2)}',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "(newadv2)",
   "language": "python",
   "name": "newadv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
