{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project root = two directories above the notebook's working directory.\n",
    "# os.path.dirname is used instead of splitting on '/' so this also works on Windows.\n",
    "absolute_path = os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd())))\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project root importable. The membership check keeps re-runs of this\n",
    "# cell from appending duplicate entries to sys.path.\n",
    "if absolute_path not in sys.path:\n",
    "    sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where exported figures are written (components joined portably).\n",
    "path_to_save = os.path.join(absolute_path, 'experiments', 'dump_results')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared configuration: similarity-matrix log directory, dataset, architectures and layer.\n",
    "# Both compared models use the same architecture. `layer` is a string because it is\n",
    "# spliced into the per-pair folder names.\n",
    "sim_mx_logs = 'logs/similarity_matrix/efficient_pnka/topk_landmarks_from_trainset'\n",
    "dataset = 'cifar10'\n",
    "arch1 = 'r18'\n",
    "arch2 = arch1\n",
    "layer = '17'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sub-folders (one per input set) that hold pnka.pt; xp_idx selects the one loaded below.\n",
    "xp_folders = [\n",
    "    'cifar10_testset',\n",
    "    'cifar100_testset',\n",
    "    'noise',\n",
    "]\n",
    "# NOTE(review): xp_names is not referenced anywhere else in this notebook.\n",
    "xp_names = ['Non Robust Model', 'Robust Model']\n",
    "xp_idx = 2  # -> 'noise'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only used to tag the output filename. NOTE(review): it does not select which\n",
    "# similarity files are loaded -- confirm it matches the landmarks behind sim_mx_logs.\n",
    "nb_landmarks = 1_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed pairs to compare, read column-wise: (0, 1), (0, 2), (1, 2).\n",
    "seeds_1, seeds_2 = np.array([[0, 0, 1],\n",
    "                             [1, 2, 2]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_pnka(sim_mx_folder):\n",
    "    \"\"\"Load the PNKA scores stored for one model pair.\n",
    "\n",
    "    Reads <absolute_path>/<sim_mx_logs>/<sim_mx_folder>/<xp_folders[xp_idx]>/pnka.pt.\n",
    "    If the stored tensor is 2-D, its diagonal is returned; otherwise the tensor as is.\n",
    "    map_location='cpu' lets tensors saved from a GPU run load on a CPU-only machine.\n",
    "    \"\"\"\n",
    "    sim = torch.load(\n",
    "        os.path.join(absolute_path, sim_mx_logs, sim_mx_folder, xp_folders[xp_idx], 'pnka.pt'),\n",
    "        map_location='cpu')\n",
    "    if sim.dim() == 2:\n",
    "        sim = torch.diag(sim)\n",
    "    return sim\n",
    "\n",
    "\n",
    "all_sims_r = []\n",
    "all_sims_nr = []\n",
    "folder_names = []\n",
    "\n",
    "for seed1, seed2 in zip(seeds_1, seeds_2):\n",
    "    folder_names.append(f'{seed1}_{seed2}')\n",
    "\n",
    "    # Non-robust model pair.\n",
    "    sim_nr = _load_pnka(\n",
    "        f'M1_{dataset}-{arch1}-seed{seed1}_l{layer}_M2_{dataset}-{arch2}-seed{seed2}_l{layer}')\n",
    "    all_sims_nr.append(sim_nr)\n",
    "\n",
    "    # Robust model pair (folders prefixed 'rob-' with an '-eps1' suffix).\n",
    "    sim_r = _load_pnka(\n",
    "        f'M1_rob-{dataset}-{arch1}-seed{seed1}-eps1_l{layer}_M2_rob-{dataset}-{arch2}-seed{seed2}-eps1_l{layer}')\n",
    "    all_sims_r.append(sim_r)\n",
    "\n",
    "    print(f'seeds {seed1}_{seed2}: non-robust mean PNKA {sim_nr.float().mean().item():.4f}, '\n",
    "          f'robust mean PNKA {sim_r.float().mean().item():.4f}')\n",
    "\n",
    "all_sims_r = torch.stack(all_sims_r)\n",
    "all_sims_nr = torch.stack(all_sims_nr)\n",
    "all_sims_r.shape, all_sims_nr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input-set sub-folder the scores above were loaded from.\n",
    "xp_folders[xp_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per seed pair, one row per entry of the similarity vectors.\n",
    "# Tensors are converted to NumPy explicitly (detached, on CPU) so pandas builds plain\n",
    "# numeric columns regardless of the tensors' device or autograd state.\n",
    "data_r = {name: sim.detach().cpu().numpy() for name, sim in zip(folder_names, all_sims_r)}\n",
    "data_nr = {name: sim.detach().cpu().numpy() for name, sim in zip(folder_names, all_sims_nr)}\n",
    "\n",
    "df_r = pd.DataFrame(data_r)\n",
    "df_nr = pd.DataFrame(data_nr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_statistics_df(df):\n",
    "    \"\"\"Return a copy of `df` with row-wise 'avg' (mean) and 'std' columns.\n",
    "\n",
    "    The statistics are computed over every column except existing 'avg'/'std'\n",
    "    columns. Calling the function again (e.g. re-running a cell) therefore gives\n",
    "    the same result instead of folding the previous statistics into the new ones.\n",
    "    'std' is pandas' default sample standard deviation (ddof=1).\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame of numeric columns (here, one column per seed pair).\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with 'avg' and 'std' set; the input is not modified.\n",
    "    \"\"\"\n",
    "    values = df.drop(columns=['avg', 'std'], errors='ignore')\n",
    "    return df.assign(avg=values.mean(axis=1), std=values.std(axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add row-wise 'avg'/'std' columns (across seed pairs) to both tables.\n",
    "df_r = get_statistics_df(df=df_r)\n",
    "df_nr = get_statistics_df(df=df_nr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Overall histogram: average PNKA per point, non-robust vs. robust models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "# Plotly's default qualitative colour sequence, used for the histogram traces.\n",
    "COLORS = px.colors.qualitative.Plotly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Output PDF path; encodes dataset, architecture, input set and landmark count.\n",
    "filename = os.path.join(\n",
    "    path_to_save,\n",
    "    f'hist_robustvsnonrobust_{dataset}-{arch1}_{xp_folders[xp_idx]}'\n",
    "    f'-top{nb_landmarks}balancedlandmarksfromtrain.pdf',\n",
    ")\n",
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted row-wise averages: robust table first, then non-robust.\n",
    "(\n",
    "    df_r['avg'].sort_values(),\n",
    "    df_nr['avg'].sort_values(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlaid histograms of the row-wise average PNKA for non-robust vs. robust models.\n",
    "# A shared bingroup puts both traces on the same bins. Under barmode='overlay' plotly\n",
    "# does not share bins automatically, and different bin widths would skew the comparison.\n",
    "X_RANGE = (-0.5, 1.0)  # visible x-axis window; values outside it are not shown\n",
    "X_DTICK = 0.1\n",
    "AXIS_STYLE = dict(linecolor='black', spikedash='solid', showgrid=True, gridcolor='lightgrey')\n",
    "\n",
    "fig = go.Figure()\n",
    "fig.add_trace(go.Histogram(x=df_nr['avg'], name='Non-Robust', bingroup='pnka', marker_color=COLORS[0]))\n",
    "fig.add_trace(go.Histogram(x=df_r['avg'], name='Robust', bingroup='pnka', marker_color=COLORS[1]))\n",
    "\n",
    "fig.update_layout(\n",
    "    barmode='overlay',\n",
    "    xaxis_title_text='PNKA',\n",
    "    yaxis_title_text='Number of Points',\n",
    "    xaxis=dict(**AXIS_STYLE, range=list(X_RANGE), tickmode='linear', tick0=X_RANGE[0], dtick=X_DTICK),\n",
    "    yaxis=AXIS_STYLE,\n",
    "    paper_bgcolor='rgba(0,0,0,0)',\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    showlegend=True,\n",
    "    legend=dict(yanchor='top', y=0.99, xanchor='left', x=0.01),\n",
    "    font=dict(family='Times New Roman', size=25, color='Black'),\n",
    ")\n",
    "fig.update_traces(opacity=0.85)\n",
    "\n",
    "# write_image fails if the target directory does not exist yet.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "fig.write_image(filename)\n",
    "print(f'Saved in {filename}')\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
