{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project root: two directory levels above the current working directory.\n",
    "# os.path.dirname is used instead of splitting on '/' so the result stays an\n",
    "# absolute path for shallow working directories and with non-POSIX separators.\n",
    "absolute_path = os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd())))\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project root importable. The membership check keeps sys.path free\n",
    "# of duplicate entries when this cell is executed more than once.\n",
    "if absolute_path not in sys.path:\n",
    "    sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_path = 'logs/eval_models'\n",
    "path_to_save = os.path.join(absolute_path, 'experiments/dump_results')\n",
    "# Figures are written into this directory further down; create it up front so\n",
    "# saving does not fail when the folder does not exist yet.\n",
    "os.makedirs(path_to_save, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the stored similarity results (joined onto the project root when loading).\n",
    "sim_mx_logs = 'logs/similarity_matrix/efficient_pnka/topk_landmarks_from_trainset'\n",
    "\n",
    "# Both compared models use the same architecture.\n",
    "arch1 = arch2 = 'r18'\n",
    "layer = 'l17'\n",
    "dataset = 'cifar100'\n",
    "\n",
    "# Seed pairs are formed element-wise: (seeds_even[i], seeds_odd[i]).\n",
    "seeds_even = np.array([0, 0, 1])\n",
    "seeds_odd = np.array([1, 2, 2])\n",
    "seeds_even, seeds_odd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perturbation_type = 'blur'\n",
    "# Kept as a string: it is only interpolated into folder and file names.\n",
    "nb_perturbed = '1000'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sims = []\n",
    "cka_sims = []\n",
    "folder_names = []\n",
    "perturbed_idxs = []\n",
    "# Load the similarity scores and the perturbed indexes of every seed pair.\n",
    "for seed1, seed2 in zip(seeds_even, seeds_odd):\n",
    "    print(seed1, seed2)\n",
    "    sim_mx_folder = f'M1_{dataset}-{arch1}-seed{seed1}_{layer}_M2_{dataset}-{arch2}-seed{seed2}_{layer}'\n",
    "    folder_names.append(f'{seed1}_{seed2}')\n",
    "\n",
    "    # The extra attack-parameter sub-folder is only used for the 'adv' perturbation.\n",
    "    attacked_params_foldername = ''\n",
    "    if perturbation_type == 'adv':\n",
    "        attacked_params_foldername = f'{dataset}-{arch1}-seed{seed1}/eps1.0_atlr0.5_atsteps100'\n",
    "    perturbed_path = os.path.join(\n",
    "        'ood_xp', perturbation_type, f'{nb_perturbed}_perturbed', attacked_params_foldername)\n",
    "    print(perturbed_path)\n",
    "\n",
    "    # Both files of a seed pair live in the same directory; build the path once.\n",
    "    pair_dir = os.path.join(absolute_path, sim_mx_logs, perturbed_path, sim_mx_folder)\n",
    "    # map_location='cpu' lets files that were saved from a GPU run load on a\n",
    "    # machine without CUDA.\n",
    "    sim_mx = torch.load(os.path.join(pair_dir, 'pnka.pt'), map_location='cpu')\n",
    "    perturbed_idx = torch.load(\n",
    "        os.path.join(pair_dir, 'perturbed_indexes.pt'), map_location='cpu')\n",
    "    perturbed_idxs.append(perturbed_idx)\n",
    "\n",
    "#     cka_mx = torch.load(\n",
    "#         os.path.join(\n",
    "#             absolute_path, 'logs/similarity_matrix/cka/', sim_mx_folder,\n",
    "#             perturbed_path, 'final_sim_xxtyyt.pt'))\n",
    "#     norm = torch.load(\n",
    "#         os.path.join(\n",
    "#             absolute_path, 'logs/similarity_matrix/cka/', sim_mx_folder,\n",
    "#             perturbed_path, 'norm.pt'))\n",
    "#     cka_sims.append(torch.trace(cka_mx)/norm)\n",
    "    # A 2-D result is reduced to its diagonal; anything else is used as loaded.\n",
    "    sim = sim_mx\n",
    "    if len(sim.shape) == 2:\n",
    "        sim = torch.diag(sim)\n",
    "    sims.append(sim)\n",
    "\n",
    "sims = torch.stack(sims)\n",
    "# cka_sims = torch.stack(cka_sims)\n",
    "sims.shape#, cka_sims.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare every seed pair's perturbed indexes with those of the first pair.\n",
    "# Works for any number of seed pairs instead of assuming exactly three; later\n",
    "# cells only use perturbed_idxs[0], so every entry is expected to be True.\n",
    "tuple(torch.equal(perturbed_idxs[0], other) for other in perturbed_idxs[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perturbed_idxs[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each seed pair's similarity vector becomes one column, keyed by its\n",
    "# 'seed1_seed2' label.\n",
    "data = dict(zip(folder_names, sims))\n",
    "df = pd.DataFrame(data)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate over the per-seed-pair columns only. Selecting them by name keeps\n",
    "# this cell idempotent: when it is re-run, the 'avg'/'std' (and later\n",
    "# 'perturbed') columns added to df are no longer folded into the statistics.\n",
    "pair_columns = list(dict.fromkeys(folder_names))  # unique labels, original order\n",
    "avg = df[pair_columns].mean(axis=1)\n",
    "std = df[pair_columns].std(axis=1)\n",
    "df['avg'] = avg\n",
    "df['std'] = std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the lookup set once: testing `idx in tensor` for every row compares the\n",
    "# index against the whole tensor each time (quadratic overall), whereas set\n",
    "# membership is a constant-time check.\n",
    "perturbed_lookup = set(torch.as_tensor(perturbed_idxs[0]).flatten().tolist())\n",
    "df['perturbed'] = [1 if idx in perturbed_lookup else 0 for idx in df.index]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Series.min() is vectorised and skips NaN; the built-in min() iterates in\n",
    "# Python and gives an order-dependent answer as soon as a NaN is present.\n",
    "float(df['avg'].min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of rows flagged as perturbed.\n",
    "int((df['perturbed'] == 1).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cka_score = torch.mean(cka_sims)\n",
    "# cka_score_std = torch.std(cka_sims)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Overall Histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bin count that is interpolated into the histogram's output filename.\n",
    "nbinsx = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Histogram per perturbation status"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the per-row average similarity by the 'perturbed' flag.\n",
    "is_perturbed = df['perturbed'] == 1\n",
    "is_notperturbed = df['perturbed'] == 0\n",
    "sim_perturbed = df.loc[is_perturbed, 'avg']\n",
    "sim_notperturbed = df.loc[is_notperturbed, 'avg']\n",
    "sim_perturbed.shape, sim_notperturbed.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "COLORS = px.colors.qualitative.Plotly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_landmarks = 1000\n",
    "# Output file for the histogram; the name records the settings of this run.\n",
    "hist_basename = (\n",
    "    f'hist_{nb_perturbed}{perturbation_type}_{dataset}-{arch1}-{layer}'\n",
    "    f'_nbinsx{nbinsx}-top{nb_landmarks}balancedlandmarksfromtrain.pdf'\n",
    ")\n",
    "filename = os.path.join(path_to_save, hist_basename)\n",
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the average PNKA score, perturbed vs. not perturbed points.\n",
    "visualization.plot_histogram(\n",
    "    all_x=[sim_perturbed, sim_notperturbed],\n",
    "    all_names=['Perturbed', 'Not Perturbed'],\n",
    "    histnorm='percent',\n",
    "#     histfunc='sum',\n",
    "    colors=COLORS[1:],\n",
    "    yaxis_title_text='Percentage of Points',\n",
    "    xaxis_title_text='Average PNKA',\n",
    "    # NOTE(review): `filename` records nbinsx{nbinsx} (currently 10) while the\n",
    "    # plot uses 20 bins here - confirm which one is intended.\n",
    "    nbinsx=20,\n",
    "    yrange=[0,100],\n",
    "#     log=True,\n",
    "    xrange=[-1,1],\n",
    "    xaxis = dict(\n",
    "        linecolor = \"black\",\n",
    "        spikedash = 'solid',\n",
    "        showgrid=True,\n",
    "        gridcolor='lightgrey',\n",
    "    ),\n",
    "#     yaxis = dict(\n",
    "#         linecolor = \"black\",\n",
    "#         showgrid=False,),\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=23,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    showlegend=False,\n",
    "    legend=dict(\n",
    "        yanchor=\"top\",\n",
    "        y=0.99,\n",
    "        xanchor=\"left\",\n",
    "        x=0.0,\n",
    "        font=dict(size=25)\n",
    "    ),\n",
    "#     x_vline=cka_score,\n",
    "#     linecolor='orange',\n",
    "#     vline_annotation=f'CKA={round(cka_score.item(),2)}(±{round(cka_score_std.item(),2)})',\n",
    "    # `filename` already starts with `path_to_save`; joining the directory a\n",
    "    # second time would duplicate it whenever `path_to_save` is a relative path.\n",
    "    save_path=filename\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each ranking bin holds a tenth of the rows (rounded down); total_bins is the\n",
    "# number of full bins that fit. Integer division avoids the float round-trip,\n",
    "# and max(..., 1) prevents a bin size of zero (and the resulting\n",
    "# ZeroDivisionError) when the frame has fewer than ten rows.\n",
    "range_points = max(len(df) // 10, 1)\n",
    "total_bins = len(df) // range_points\n",
    "total_bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `px` and `COLORS` are already defined by an earlier cell; only preview the\n",
    "# first palette entry here instead of importing and assigning them again.\n",
    "COLORS[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the rows by their average score; reset_index() renumbers the rows and\n",
    "# keeps the previous row labels as a regular column.\n",
    "ascending = True\n",
    "df_sorted = df.sort_values(by='avg', ascending=ascending).reset_index()\n",
    "df_sorted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_df = []\n",
    "perturbed_split_df = []\n",
    "notperturbed_split_df = []\n",
    "\n",
    "avg_pnka, std_pnka = [], []\n",
    "perturbed_percentage_per_bin = []\n",
    "notperturbed_percentage_per_bin = []\n",
    "\n",
    "# Walk over the ranked rows in consecutive chunks of `range_points` rows and\n",
    "# record, per chunk, the share of perturbed / unperturbed rows and the score statistics.\n",
    "for bin_idx in range(total_bins):\n",
    "    bin_start = bin_idx * range_points\n",
    "    partial_df_sorted = df_sorted.iloc[bin_start:bin_start + range_points]\n",
    "    partial_df_perturbed = partial_df_sorted[partial_df_sorted['perturbed'] == 1] # perturbed\n",
    "    partial_df_notperturbed = partial_df_sorted[partial_df_sorted['perturbed'] == 0]\n",
    "\n",
    "    bin_size = len(partial_df_sorted)\n",
    "    perturbed_fraction = len(partial_df_perturbed) / bin_size\n",
    "    notperturbed_fraction = len(partial_df_notperturbed) / bin_size\n",
    "    # The values are printed flat; a misplaced parenthesis used to wrap most of\n",
    "    # them into a single nested tuple.\n",
    "    print(bin_size, 'p', len(partial_df_perturbed), perturbed_fraction,\n",
    "          'np', len(partial_df_notperturbed), notperturbed_fraction)\n",
    "\n",
    "    perturbed_percentage_per_bin.append(perturbed_fraction)\n",
    "    notperturbed_percentage_per_bin.append(notperturbed_fraction)\n",
    "    avg_pnka.append(partial_df_sorted['avg'].mean())\n",
    "    # Standard deviation of the scores in this bin (the mean was appended here before).\n",
    "    std_pnka.append(partial_df_sorted['avg'].std())\n",
    "#     split_df.append(partial_df_sorted)\n",
    "#     perturbed_split_df.append(partial_df_perturbed)\n",
    "#     notperturbed_split_df.append(partial_df_notperturbed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# range_for_plot = np.concatenate((np.arange(-0.95, 0.0, 0.1), np.arange(0.05, 1, 0.1)), 0)\n",
    "# range_for_plot, len(range_for_plot)\n",
    "# One x position per ranking bin: 0 .. total_bins - 1.\n",
    "range_for_plot = np.arange(0, total_bins)\n",
    "range_for_plot, len(range_for_plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_landmarks = 1000\n",
    "# Output file for the ranking bar chart; the name records the settings of this run.\n",
    "ranking_basename = (\n",
    "    f'hist_{nb_perturbed}{perturbation_type}_{dataset}-{arch1}-{layer}'\n",
    "    f'_ranking-top{nb_landmarks}balancedlandmarksfromtrain.pdf'\n",
    ")\n",
    "filename = os.path.join(path_to_save, ranking_basename)\n",
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "\n",
    "# Stacked bars: share of perturbed / unperturbed rows in every ranking bin.\n",
    "fig = go.Figure(data=[\n",
    "    go.Bar(name='Perturbed Points',\n",
    "           x=range_for_plot,\n",
    "           y=perturbed_percentage_per_bin,\n",
    "           marker_color=COLORS[1]),\n",
    "    go.Bar(name='Unperturbed Points',\n",
    "           x=range_for_plot,\n",
    "           y=notperturbed_percentage_per_bin,\n",
    "           marker_color=COLORS[2]),\n",
    "])\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    xaxis_title_text=f'Ranking of PNKA score ({range_points} points in each bin)',\n",
    "    # Was `paxis_title_text`, which is not a layout property (plotly rejects\n",
    "    # unknown properties); the y-axis title is set through `yaxis_title_text`.\n",
    "    yaxis_title_text='Percentage',\n",
    "    barmode='stack',\n",
    "    yaxis = dict(\n",
    "        linecolor = \"black\",\n",
    "        spikedash = 'solid',\n",
    "        showgrid=True,\n",
    "        gridcolor='lightgrey'),\n",
    "    xaxis = dict(\n",
    "        linecolor = \"black\",\n",
    "        spikedash = 'solid',\n",
    "        showgrid=True,\n",
    "        gridcolor='lightgrey',\n",
    "        tickmode = 'array',\n",
    "        tickvals = range_for_plot,\n",
    "#         ticktext = x_axis_names\n",
    "    ),\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=23,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    legend=dict(\n",
    "        orientation=\"h\",\n",
    "        yanchor=\"top\",\n",
    "        y=1.2,\n",
    "        xanchor=\"left\",\n",
    "        x=0.01\n",
    "    ),\n",
    "    paper_bgcolor='rgba(0,0,0,0)',\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    ")\n",
    "fig.update_yaxes(range=[0, 1])\n",
    "\n",
    "\n",
    "# Annotate every bin with its average score (rounded to three decimals).\n",
    "for x_vline_item, avg_score in zip(range_for_plot, avg_pnka):\n",
    "    score = round(avg_score, 3)\n",
    "    fig.add_vline(x=x_vline_item, line_width=2, line_dash='dash',\n",
    "                  line_color='black', annotation_text=str(score),\n",
    "                  annotation=dict(font_size=17, font_family=\"Times New Roman\"),\n",
    "                  annotation_position='top')\n",
    "\n",
    "fig.write_image(filename)\n",
    "print(f'Saved in {filename}')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(range_for_plot), len(avg_pnka)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
