{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "absolute_path = \"/\".join(os.path.abspath(os.getcwd()).split('/')[:-2])\n",
    "absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(absolute_path)\n",
    "\n",
    "from utils.experiments import visualization, df_analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(os.path.join(absolute_path, 'data', 'SemBias.csv'))\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_to_groups = {\n",
    "    0: 'gender-defining',\n",
    "    1: 'gender-neutral',\n",
    "    2: 'gender-stereotype'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'sembias'\n",
    "model1_name = 'glove'\n",
    "model2_name = 'gp-gn-glove'\n",
    "folder = f'M1_{model1_name}_M2_{model2_name}_{dataset}'\n",
    "folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx_logs = 'logs/similarity_matrix/pnka/nlp/downloaded/'\n",
    "sim_mx = torch.load(\n",
    "    os.path.join(\n",
    "        absolute_path, sim_mx_logs, folder, 'pnka.pt'))\n",
    "sim = torch.diag(sim_mx)\n",
    "sim.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim.sum() / sim.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CKA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx_logs = 'logs/similarity_matrix/cka/nlp/downloaded/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_mx = torch.load(\n",
    "    os.path.join(\n",
    "        absolute_path, sim_mx_logs, folder, 'final_sim_xxtyyt.pt'))\n",
    "norm = torch.load(\n",
    "    os.path.join(\n",
    "        absolute_path, sim_mx_logs, folder, 'norm.pt'))\n",
    "torch.trace(sim_mx) / norm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CKA for subgroups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for group_id, group_name in labels_to_groups.items():\n",
    "    print(group_name)\n",
    "    folder = f'M1_{group_name}-{model1_name}_M2_{group_name}-{model2_name}_{dataset}'\n",
    "    sim_mx = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_logs_path, sim_mx_logs, folder, 'final_sim_xxtyyt.pt'))\n",
    "    norm = torch.load(\n",
    "        os.path.join(\n",
    "            absolute_logs_path, sim_mx_logs, folder, 'norm.pt'))\n",
    "    print(f'CKA: {round((torch.trace(sim_mx) / norm).item(), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "(newadv2)",
   "language": "python",
   "name": "newadv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
