{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66742f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63b6811e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base directory two levels above the current working directory; the cells below\n",
    "# join it with the 'data/', 'logs/' and 'experiments/' sub-paths.\n",
    "# os.path.dirname is used rather than splitting on '/' so the result is also\n",
    "# correct where the path separator is not '/' (e.g. Windows) and never collapses\n",
    "# to an empty string.\n",
    "absolute_path = os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd())))\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7efb3d7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project's `utils` package importable. The append is guarded so that\n",
    "# re-running this cell does not keep adding duplicate entries to sys.path.\n",
    "if absolute_path not in sys.path:\n",
    "    sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7908a72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that the figure's save_path (further down) is built from.\n",
    "results_subdir = 'experiments/dump_results'\n",
    "path_to_save = os.path.join(absolute_path, results_subdir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d75f18d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SemBias table; its 'label' column is used below to split rows by class.\n",
    "sembias_csv = os.path.join(absolute_path, 'data/SemBias.csv')\n",
    "data = pd.read_csv(sembias_csv)\n",
    "data.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede15249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 'label' column on its own, followed by its shape as a quick sanity check.\n",
    "labels = data.loc[:, 'label']\n",
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc7f0e22",
   "metadata": {},
   "outputs": [],
   "source": [
    "logs = 'logs/representations/nlp/downloaded/'\n",
    "dataset = 'sembias'\n",
    "model1_name, model2_name = 'glove', 'gp-gn-glove'\n",
    "\n",
    "# Both models use the same layout: <logs>/<model>/<dataset>/all_magnitudes.pt\n",
    "magnitudes_file = 'all_magnitudes.pt'\n",
    "path_m1, path_m2 = (\n",
    "    os.path.join(absolute_path, logs, name, dataset, magnitudes_file)\n",
    "    for name in (model1_name, model2_name)\n",
    ")\n",
    "# path_m2 = os.path.join(absolute_path, logs, model2_name, dataset,  'debiased_magnitudes.pt') # for gn-glove without gender dim\n",
    "path_m1, path_m2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "likely-interview",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `dataset`, `model1_name` and `model2_name` are the values set in the paths cell above.\n",
    "sim_mx_logs = 'logs/similarity_matrix/pnka/nlp/downloaded/'\n",
    "folder = f'M1_{model1_name}_M2_{model2_name}_{dataset}'\n",
    "\n",
    "# `sim_mx_logs` is relative to `absolute_path`, exactly like `logs` above; the\n",
    "# notebook defines no other base-path variable, so `absolute_path` is the root.\n",
    "sim_mx = torch.load(\n",
    "    os.path.join(absolute_path, sim_mx_logs, folder, 'pnka.pt'))\n",
    "\n",
    "# Keep only the diagonal of the loaded matrix.\n",
    "# NOTE(review): presumably entry i is the PNKA similarity of item i between the\n",
    "# two models -- confirm against the code that writes pnka.pt.\n",
    "sim = torch.diag(sim_mx)\n",
    "sim.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "002b874d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved magnitudes of both models from the paths built above.\n",
    "mag_m1, mag_m2 = (torch.load(pt_path) for pt_path in (path_m1, path_m2))\n",
    "mag_m1.shape, mag_m2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "angry-sellers",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the magnitudes to NumPy once and reuse them for the derived columns.\n",
    "m1_np, m2_np = np.array(mag_m1), np.array(mag_m2)\n",
    "mag_diff = m2_np - m1_np\n",
    "\n",
    "df = pd.DataFrame({\n",
    "    'sim': sim,\n",
    "    'label': data['label'],\n",
    "    f'mag_{model1_name}': mag_m1,\n",
    "    f'mag_{model2_name}': mag_m2,\n",
    "})\n",
    "\n",
    "# Magnitude difference (model 2 minus model 1): raw, then divided by model 1's magnitude.\n",
    "df[f'd({model2_name}-{model1_name})'] = mag_diff\n",
    "df[f'd({model2_name}-{model1_name})/{model1_name}'] = mag_diff / m1_np\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1fd1179",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row indices of each class (label values 0, 1 and 2), kept as plain Python lists.\n",
    "index_label0, index_label1, index_label2 = (\n",
    "    df.index[df['label'] == label_value].tolist()\n",
    "    for label_value in (0, 1, 2)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1012947",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pearson correlation between the two models' magnitudes: over all rows first,\n",
    "# then restricted to each class (gd / gn / gs correspond to labels 0 / 1 / 2).\n",
    "m1_values, m2_values = np.array(mag_m1), np.array(mag_m2)\n",
    "\n",
    "r_overall = round(stats.pearsonr(mag_m1, mag_m2)[0], 4)\n",
    "r_gd, r_gn, r_gs = (\n",
    "    round(stats.pearsonr(np.take(m1_values, idx), np.take(m2_values, idx))[0], 4)\n",
    "    for idx in (index_label0, index_label1, index_label2)\n",
    ")\n",
    "print(f'overall: {r_overall} (gd: {r_gd}; gn: {r_gn}; gs: {r_gs})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08683099",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class scatter inputs: x is the 'sim' column, y is the magnitude difference\n",
    "# divided by model 1's magnitude. One list per class, in label order 0, 1, 2.\n",
    "rel_change_col = f'd({model2_name}-{model1_name})/{model1_name}'\n",
    "label_indices = (index_label0, index_label1, index_label2)\n",
    "\n",
    "all_x = [list(df['sim'].iloc[idx]) for idx in label_indices]\n",
    "all_y = [list(df[rel_change_col].iloc[idx]) for idx in label_indices]\n",
    "\n",
    "# Trace names, in the same label order.\n",
    "all_names = ['Gender Definition', 'Gender Neutral', 'Gender Stereotype']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab933df6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Axis titles handed to the scatter-plot calls below.\n",
    "xaxis_title_text, yaxis_title_text = 'PNKA', 'Percentage Difference'\n",
    "xaxis_title_text, yaxis_title_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b776237",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# File name of the saved figure; it embeds both model names.\n",
    "save_filename = (\n",
    "    \"gfairness_scatter_plot_deltamagnitudes_vs_similarity_byclass_\"\n",
    "    f\"{model1_name}_{model2_name}_scaled.pdf\"\n",
    ")\n",
    "save_filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a718e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Styling passed to the plot helper: text font, legend placement and x-axis look.\n",
    "font_cfg = {'family': 'Times New Roman', 'size': 23, 'color': 'Black'}\n",
    "legend_cfg = {\n",
    "    'yanchor': 'top',\n",
    "    'y': 0.99,\n",
    "    'xanchor': 'left',\n",
    "    'x': 0.01,\n",
    "    'title_font_family': 'Times New Roman',\n",
    "    'font': {'size': 25},\n",
    "}\n",
    "xaxis_cfg = {\n",
    "    'zerolinecolor': 'black',\n",
    "    'zeroline': True,\n",
    "    'zerolinewidth': 2,\n",
    "    'linecolor': 'black',\n",
    "    'spikedash': 'solid',\n",
    "    'showgrid': True,\n",
    "    'gridcolor': 'lightgrey',\n",
    "}\n",
    "\n",
    "# Relative magnitude change against PNKA similarity, one series per class.\n",
    "# save_path points at path_to_save/save_filename (how it is used is decided by\n",
    "# plot_scatter_plot, whose code is not part of this notebook).\n",
    "visualization.plot_scatter_plot(\n",
    "    all_x, all_y, all_names,\n",
    "    xaxis_title_text,\n",
    "    yaxis_title_text,\n",
    "    y_range=[-10, 30],\n",
    "    x_range=[0, 1],\n",
    "    font=font_cfg,\n",
    "#     text_scatter=aux_text_scatter,\n",
    "    mode='markers+text',\n",
    "    legend=legend_cfg,\n",
    "    showlegend=False,\n",
    "    xaxis=xaxis_cfg,\n",
    "    marker={'size': 10},\n",
    "    colors=['salmon', 'mediumseagreen', 'cornflowerblue'],\n",
    "    save_path=os.path.join(path_to_save, save_filename)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "chronic-anime",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the first model's name; the next two cells plot the magnitudes loaded for it (mag_m1).\n",
    "model1_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fca7de2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class scatter inputs: x is the PNKA similarity, y is the first model's\n",
    "# magnitude. One array per class, in label order 0, 1, 2.\n",
    "sim_np, mag_np = np.array(sim), np.array(mag_m1)\n",
    "label_indices = (index_label0, index_label1, index_label2)\n",
    "\n",
    "all_x = [np.take(sim_np, idx) for idx in label_indices]\n",
    "all_y = [np.take(mag_np, idx) for idx in label_indices]\n",
    "\n",
    "# Trace names, in the same label order.\n",
    "all_names = ['Gender Definition', 'Gender Neutral', 'Gender Stereotype']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "proud-teach",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The y-values plotted here are the first model's magnitudes (see the cell above),\n",
    "# so the y-axis gets its own title instead of reusing `yaxis_title_text`, which\n",
    "# reads 'Percentage Difference' and belongs to the relative-change figure.\n",
    "magnitude_axis_title = f'Magnitude ({model1_name})'\n",
    "\n",
    "visualization.plot_scatter_plot(\n",
    "    all_x, all_y, all_names,\n",
    "    xaxis_title_text, magnitude_axis_title,\n",
    "    y_range=[-4, 6],\n",
    "    x_range=[0, 1],\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=23,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    legend=dict(\n",
    "        yanchor=\"top\",\n",
    "        y=0.99,\n",
    "        xanchor=\"left\",\n",
    "        x=0.01,\n",
    "        title_font_family=\"Times New Roman\",\n",
    "        font=dict(size=25)\n",
    "    ),\n",
    "    colors=['salmon', 'mediumseagreen', 'cornflowerblue'],\n",
    "    # Saving stays disabled: `save_filename` is the name already used for the\n",
    "    # relative-change figure, so enabling it here would target the same file.\n",
    "#     save_path=os.path.join(path_to_save, save_filename)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "threatened-horizon",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the second model's name; the next two cells plot the magnitudes loaded for it (mag_m2).\n",
    "model2_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "imperial-tower",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class scatter inputs: x is the PNKA similarity, y is the second model's\n",
    "# magnitude. One array per class, in label order 0, 1, 2.\n",
    "sim_np, mag_np = np.array(sim), np.array(mag_m2)\n",
    "label_indices = (index_label0, index_label1, index_label2)\n",
    "\n",
    "all_x = [np.take(sim_np, idx) for idx in label_indices]\n",
    "all_y = [np.take(mag_np, idx) for idx in label_indices]\n",
    "\n",
    "# Trace names, in the same label order.\n",
    "all_names = ['Gender Definition', 'Gender Neutral', 'Gender Stereotype']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "copyrighted-extreme",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The y-values plotted here are the second model's magnitudes (see the cell above),\n",
    "# so the y-axis gets its own title instead of reusing `yaxis_title_text`, which\n",
    "# reads 'Percentage Difference' and belongs to the relative-change figure.\n",
    "magnitude_axis_title = f'Magnitude ({model2_name})'\n",
    "\n",
    "visualization.plot_scatter_plot(\n",
    "    all_x, all_y, all_names,\n",
    "    xaxis_title_text, magnitude_axis_title,\n",
    "    y_range=[-4, 6],\n",
    "    x_range=[0, 1],\n",
    "    font=dict(\n",
    "        family=\"Times New Roman\",\n",
    "        size=18,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    legend=dict(\n",
    "        yanchor=\"top\",\n",
    "        y=0.99,\n",
    "        xanchor=\"left\",\n",
    "        x=0.01,\n",
    "        title_font_family=\"Times New Roman\",\n",
    "        font=dict(size=18)\n",
    "    ),\n",
    "    colors=['salmon', 'mediumseagreen', 'cornflowerblue'],\n",
    "    # Saving stays disabled: `save_filename` is the name already used for the\n",
    "    # relative-change figure, so enabling it here would target the same file.\n",
    "#     save_path=os.path.join(path_to_save, save_filename)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fundamental-firmware",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "flying-ribbon",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "arbitrary-lesson",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
