{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "absolute_path = \"/\".join(os.path.abspath(os.getcwd()).split('/')[:-2])\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_save = os.path.join(absolute_path, 'experiments/dump_results')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(os.path.join(absolute_path, 'data/SemBias.csv'))\n",
    "data.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx_logs = 'logs/similarity_matrix/pnka/nlp/downloaded/'\n",
    "dataset = 'sembias'\n",
    "model1_name = 'glove'\n",
    "model2_name = 'gn-glove'\n",
    "# model2_name = 'gp-glove'\n",
    "folder = f'M1_{model1_name}_M2_{model2_name}_{dataset}'\n",
    "\n",
    "sim_mx = torch.load(\n",
    "    os.path.join(\n",
    "        absolute_logs_path, sim_mx_logs, folder, 'pnka.pt'))\n",
    "sim = torch.diag(sim_mx)\n",
    "sim.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cka_sim_mx_logs = 'logs/similarity_matrix/cka/nlp/downloaded/'\n",
    "sim_mx_cka = torch.load(\n",
    "    os.path.join(\n",
    "        absolute_logs_path, cka_sim_mx_logs, folder, 'final_sim_xxtyyt.pt'))\n",
    "norm = torch.load(\n",
    "    os.path.join(\n",
    "        absolute_logs_path, cka_sim_mx_logs, folder, 'norm.pt'))\n",
    "sim_mx_cka.shape, norm.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cka_score = torch.trace(sim_mx_cka) / norm\n",
    "cka_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sim = df_analysis.create_df(sim=sim, labels=data['label'])\n",
    "print(len(df_sim))\n",
    "df_sim.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Overall Histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nbinsx = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Histogram per correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sim_genderdef = df_sim[df_sim['labels'] == 0]['sim']\n",
    "sim_genderneutral = df_sim[df_sim['labels'] == 1]['sim']\n",
    "sim_genderstereo = df_sim[df_sim['labels'] == 2]['sim']\n",
    "sim_genderdef.shape, sim_genderneutral.shape, sim_genderstereo.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "round(cka_score.item(),2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualization.plot_histogram(\n",
    "    all_x=[sim_genderdef, sim_genderneutral, sim_genderstereo],\n",
    "    all_names=['Gender Definition', 'Gender Neutral', 'Gender Stereotype'],\n",
    "    histnorm='percent',\n",
    "    yaxis_title_text='Percentage of Points',\n",
    "    xaxis_title_text='PNKA',\n",
    "    marker_patterns=['.', '-', 'x'],\n",
    "    colors=['salmon', 'mediumseagreen', 'cornflowerblue'],\n",
    "    yrange=[0,100],\n",
    "    xrange=[0,1],\n",
    "    xbins=dict(\n",
    "        start=0,\n",
    "        end=1,\n",
    "        size=0.1\n",
    "    ),\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=23,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    legend=dict(\n",
    "        yanchor=\"top\",\n",
    "        y=0.99,\n",
    "        xanchor=\"left\",\n",
    "        x=0.01,\n",
    "        title_font_family=\"Times New Roman\",\n",
    "        font=dict(size=25)\n",
    "    ),\n",
    "    showlegend=False,\n",
    "#     legend=dict(\n",
    "#         orientation=\"h\",\n",
    "#         yanchor=\"bottom\",\n",
    "#         y=1.02,\n",
    "#         xanchor=\"right\",\n",
    "#         x=1\n",
    "#     ),\n",
    "    x_vline=cka_score,\n",
    "    linecolor='orange',\n",
    "    vline_annotation=f'CKA={round(cka_score.item(),2)}',\n",
    "    save_path=os.path.join(path_to_save, f'gfairness_hist_byclass_{folder}_nbinsx{nbinsx}.pdf')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
