{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "COLORS = px.colors.qualitative.Plotly\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "absolute_path = \"/\".join(os.path.abspath(os.getcwd()).split('/')[:-2])\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_path = 'logs/eval_models'\n",
    "path_to_save = os.path.join(absolute_path, 'experiments/dump_results')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx_logs = 'logs/similarity_matrix/efficient_pnka/topk_landmarks_from_trainset'\n",
    "arch1 = 'incep'\n",
    "arch2 = arch1\n",
    "layer = 'l17'\n",
    "dataset = 'cifar100'\n",
    "\n",
    "seeds_even = np.array([0,0,1])\n",
    "seeds_odd = np.array([1,2,2])\n",
    "seeds_even, seeds_odd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nbinsx = 10\n",
    "nbs_perturbed = np.array([1000,3000,5000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perturbation_type = 'blur'\n",
    "# perturbation_type = 'elastic'\n",
    "# perturbation_type = 'color_jitter'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for nb_perturbed in nbs_perturbed:\n",
    "    print(nb_perturbed)\n",
    "    sims = []\n",
    "    cka_sims = []\n",
    "    folder_names = []\n",
    "    perturbed_idxs = []\n",
    "    for idx, (seed1,seed2) in enumerate(zip(seeds_even, seeds_odd)):\n",
    "        print(seed1, seed2)\n",
    "        sim_mx_folder = f'M1_{dataset}-{arch1}-seed{seed1}_{layer}_M2_{dataset}-{arch2}-seed{seed2}_{layer}'\n",
    "        folder_names.append(f'{seed1}_{seed2}')\n",
    "\n",
    "        attacked_params_foldername = ''\n",
    "        if perturbation_type == 'adv':\n",
    "            attacked_params_foldername = f'{dataset}-{arch1}-seed{seed1}/eps1.0_atlr0.5_atsteps100'\n",
    "        perturbed_path = os.path.join(\n",
    "            'ood_xp', perturbation_type, f'{nb_perturbed}_perturbed', attacked_params_foldername)\n",
    "        print(perturbed_path)\n",
    "        sim_mx = torch.load(\n",
    "            os.path.join(\n",
    "                absolute_path, sim_mx_logs, perturbed_path, sim_mx_folder, 'pnka.pt'))\n",
    "        perturbed_idx = torch.load(\n",
    "            os.path.join(\n",
    "                absolute_path, sim_mx_logs, perturbed_path, sim_mx_folder, 'perturbed_indexes.pt'))\n",
    "        perturbed_idxs.append(perturbed_idx)\n",
    "\n",
    "        sim = sim_mx\n",
    "        if len(sim.shape) == 2:\n",
    "            sim = torch.diag(sim)\n",
    "        sims.append(sim)\n",
    "\n",
    "    sims = torch.stack(sims)\n",
    "\n",
    "    \n",
    "    data = {}\n",
    "    for folder_name, sim in zip(folder_names, sims):\n",
    "        data[folder_name] = sim\n",
    "    df = pd.DataFrame(data)\n",
    "\n",
    "    avg = df.mean(axis=1)\n",
    "    std = df.std(axis=1)\n",
    "    df['avg'] = avg\n",
    "    df['std'] = std\n",
    "    df['perturbed'] = [1 if idx in perturbed_idxs[0] else 0 for idx in df.index]\n",
    "\n",
    "    sim_perturbed = df[df['perturbed'] == 1]['avg']\n",
    "    sim_notperturbed = df[df['perturbed'] == 0]['avg']\n",
    "\n",
    "    range_points = int(len(df) / 10)\n",
    "    total_bins = int(len(df) / range_points)\n",
    "\n",
    "    \n",
    "    ascending = True\n",
    "    df_sorted = df.sort_values('avg', ascending=ascending)\n",
    "    df_sorted = df_sorted.reset_index()\n",
    "\n",
    "    split_df = []\n",
    "    perturbed_split_df = []\n",
    "    notperturbed_split_df = []\n",
    "\n",
    "    avg_pnka, std_pnka = [], []\n",
    "    perturbed_percentage_per_bin = []\n",
    "    notperturbed_percentage_per_bin = []\n",
    "\n",
    "    for bin_idx in range(total_bins):\n",
    "        partial_df_sorted = df_sorted.iloc[bin_idx*range_points:(bin_idx*range_points)+range_points]\n",
    "        partial_df_perturbed = partial_df_sorted[partial_df_sorted['perturbed'] == 1] # perturbed\n",
    "        partial_df_notperturbed = partial_df_sorted[partial_df_sorted['perturbed'] == 0]\n",
    "        print(len(partial_df_sorted), 'p', (len(partial_df_perturbed), len(partial_df_perturbed)/len(partial_df_sorted),\n",
    "                                      'np', len(partial_df_notperturbed),len(partial_df_notperturbed)/len(partial_df_sorted)))\n",
    "\n",
    "        perturbed_percentage_per_bin.append(len(partial_df_perturbed)/len(partial_df_sorted))\n",
    "        notperturbed_percentage_per_bin.append(len(partial_df_notperturbed)/len(partial_df_sorted))\n",
    "        avg_pnka.append(partial_df_sorted['avg'].mean())\n",
    "        std_pnka.append(partial_df_sorted['avg'].mean())\n",
    "\n",
    "    range_for_plot = np.arange(total_bins)\n",
    "\n",
    "    filename = os.path.join(path_to_save, f'hist_{nb_perturbed}{perturbation_type}_{dataset}-{arch1}-{layer}_ranking-top{nb_landmarks}balancedlandmarksfromtrain.pdf')\n",
    "    \n",
    "    fig = go.Figure(data=[\n",
    "        go.Bar(name='Perturbed Points',\n",
    "               x=range_for_plot,\n",
    "               y=perturbed_percentage_per_bin,\n",
    "               marker_color=COLORS[1]),\n",
    "        go.Bar(name='Unperturbed Points',\n",
    "               x=range_for_plot,\n",
    "               y=notperturbed_percentage_per_bin,\n",
    "               marker_color=COLORS[2]),\n",
    "    ])\n",
    "\n",
    "\n",
    "    fig.update_layout(\n",
    "        xaxis_title_text=f'Ranking of PNKA score ({range_points} points in each bin)',\n",
    "        yaxis_title_text=f'Percentage',\n",
    "        barmode='stack',\n",
    "        yaxis = dict(\n",
    "            linecolor = \"black\",\n",
    "            spikedash = 'solid',\n",
    "            showgrid=True,\n",
    "            gridcolor='lightgrey'),\n",
    "        xaxis = dict(\n",
    "            linecolor = \"black\",\n",
    "            spikedash = 'solid',\n",
    "            showgrid=True,\n",
    "            gridcolor='lightgrey',\n",
    "            tickmode = 'array',\n",
    "            tickvals = range_for_plot,\n",
    "    #         ticktext = x_axis_names\n",
    "        ),\n",
    "        font=dict(\n",
    "                family=\"Times New Roman\",\n",
    "                size=23,\n",
    "                color=\"Black\"\n",
    "            ),\n",
    "        legend=dict(\n",
    "            orientation=\"h\",\n",
    "            yanchor=\"top\",\n",
    "            y=1.2,\n",
    "            xanchor=\"left\",\n",
    "            x=0.01\n",
    "        ),\n",
    "        paper_bgcolor='rgba(0,0,0,0)',\n",
    "        plot_bgcolor='rgba(0,0,0,0)',\n",
    "    )\n",
    "    fig.update_yaxes(range=[0, 1])\n",
    "\n",
    "\n",
    "    for x_vline_idx, (x_vline_item, avg_score) in enumerate(zip(range_for_plot, avg_pnka)):\n",
    "        score = round(avg_score,3)\n",
    "        fig.add_vline(x=x_vline_item, line_width=2, line_dash='dash',\n",
    "                      line_color='black', annotation_text=str(score),\n",
    "                      annotation=dict(font_size=16, font_family=\"Times New Roman\"),\n",
    "                      annotation_position='top')\n",
    "\n",
    "    fig.write_image(filename)\n",
    "    print(f'Saved in {filename}')\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "foo",
   "language": "python",
   "name": "foo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
