{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0349d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30175038",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three directory levels above the working directory; later cells treat this as the\n",
    "# project root (utils/, data/ and logs/ are resolved under it).\n",
    "cwd_parts = os.path.abspath(os.getcwd()).split('/')\n",
    "absolute_path = '/'.join(cwd_parts[:-3])\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a89eac6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project root importable; the guard keeps re-runs of this cell\n",
    "# from appending the same entry to sys.path again and again.\n",
    "if absolute_path not in sys.path:\n",
    "    sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77651b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figures produced below are written here; create the folder up front in case it\n",
    "# does not exist yet (exist_ok makes re-running this cell harmless).\n",
    "path_to_save = os.path.join(absolute_path, 'experiments/dump_results')\n",
    "os.makedirs(path_to_save, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c6d9733",
   "metadata": {},
   "source": [
    "### Select the similarity results and load the test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "908e6b77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifiers of the run whose results are analysed (dataset, architecture, seed, layer).\n",
    "dataset = 'cifar10'\n",
    "arch1 = 'r18'\n",
    "seed1 = '0'  # kept as a string: it is only interpolated into folder names\n",
    "layer = 17\n",
    "\n",
    "# PNKA similarity matrices live under <absolute_path>/<sim_mx_path>/<sim_mx_folder>.\n",
    "sim_mx_path = 'logs/similarity_matrix/pnka/'\n",
    "sim_mx_folder = f'M1_{dataset}-{arch1}-seed{seed1}_l{layer}_l{layer}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "furnished-probe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "# Test-time preprocessing: tensor conversion followed by per-channel normalization.\n",
    "transform_test = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(\n",
    "        mean=(0.4914, 0.4822, 0.4465),\n",
    "        std=(0.2023, 0.1994, 0.2010),\n",
    "    ),\n",
    "])\n",
    "\n",
    "# CIFAR-10 test split read from <absolute_path>/data; no download is attempted.\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    root=f'{absolute_path}/data',\n",
    "    train=False,\n",
    "    transform=transform_test,\n",
    "    download=False,\n",
    ")\n",
    "testset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rolled-stake",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stored layer representations for the same run. map_location='cpu' lets tensors that\n",
    "# were saved from a GPU load on a CPU-only machine.\n",
    "# NOTE(review): `features` is not used by any later cell in this notebook.\n",
    "features = torch.load(\n",
    "    os.path.join(absolute_path, 'logs/representations', f'{dataset}-{arch1}-seed{seed1}', f'l{layer}.pt'),\n",
    "    map_location='cpu')\n",
    "features.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ongoing-anthropology",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names in class_to_idx order; later indexed by integer label (assumes that order matches the label ids).\n",
    "class_names = [name for name in testset.class_to_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "absolute-andrews",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "minor-patient",
   "metadata": {},
   "source": [
    "### Pie charts: class distribution of the lowest-similarity test samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "right-jumping",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SEL_COLORS = ['lightblue',\n",
    "#   'lightred',\n",
    "#   'darkyellow',\n",
    "#   'rose',\n",
    "#   'orange',\n",
    "#   'green',\n",
    "#   'brown',\n",
    "#   'lillac',\n",
    "#   'salmon',\n",
    "#   'mediumvioletred',\n",
    "#   'lightgreen']\n",
    "# SEL_COLORS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adapted-fraction",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neuron ids to analyse; each id selects a `neuron_<id>` folder of similarity results.\n",
    "# Alternatives used before: [261, 202, 94] or np.arange(512).\n",
    "neurons = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "differential-chocolate",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of lowest-similarity test samples whose class distribution is plotted.\n",
    "n_lowest = 100\n",
    "\n",
    "for neuron_id in neurons:\n",
    "    # NOTE(review): presumably the PNKA similarity between the layer representation with and\n",
    "    # without `neuron_id`; its diagonal is used as one score per test sample - confirm.\n",
    "    sim_mx = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_path, sim_mx_path, sim_mx_folder,\n",
    "            'remove_dimension/1/all', f'neuron_{neuron_id}',\n",
    "            'pnka.pt'),\n",
    "        map_location='cpu')\n",
    "    sim = torch.diag(sim_mx).numpy()\n",
    "    assert len(sim) == len(testset.targets), 'similarity scores and test labels are misaligned'\n",
    "\n",
    "    df = pd.DataFrame({'sim': sim, 'labels': testset.targets})\n",
    "    lowest_df = df.sort_values('sim').head(n_lowest)\n",
    "\n",
    "    # Class frequencies among those samples, most frequent first.\n",
    "    label_counts = lowest_df['labels'].value_counts()\n",
    "    labels_class = label_counts.index.tolist()\n",
    "    count_freq = label_counts.tolist()\n",
    "    name_class = [class_names[label_id] for label_id in labels_class]\n",
    "\n",
    "    visualization.plot_pie_chart(\n",
    "        labels=name_class, values=count_freq, ids_colors=labels_class,\n",
    "        font=dict(family=\"Times New Roman\", size=18, color=\"Black\"),\n",
    "        legend=dict(yanchor=\"top\", y=0.99, xanchor=\"left\", x=-0.15),\n",
    "        save_path=os.path.join(\n",
    "            path_to_save,\n",
    "            f'pie_distrclasses_{sim_mx_folder}_removingneuron{neuron_id}.pdf'),\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "correct-globe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "(newadv2)",
   "language": "python",
   "name": "newadv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
