{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_ut = np.load(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_gpt2-large-untrained-sp-hfgpt_0.npz\")\n",
    "X_t = np.load(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_gpt2-large-sp-hfgpt.npz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_ut.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BIL = X_ut['encoder.h.0']\n",
    "static = X_ut['embedding+pos']\n",
    "BIL_trained = X_t['encoder.h.21']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = np.load('/home3/name/what-is-brainscore/temp_data_all/data_labels_pereira.npy')\n",
    "y = []\n",
    "for d in dl:\n",
    "    if '243' in d:\n",
    "        y.append(0)\n",
    "    if '384' in d:\n",
    "        y.append(1)\n",
    "        \n",
    "y = np.array(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_labels = np.load('/home3/name/what-is-brainscore/temp_data_all/text_by_labels_pereira.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegressionCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idxs_243 = np.argwhere(y==0)\n",
    "idxs_384 = np.argwhere(y==1)\n",
    "np.random.shuffle(idxs_243)\n",
    "np.random.shuffle(idxs_384)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def return_idxs(dl, exp, pn):\n",
    "    idxs = []\n",
    "    for j, d in enumerate(dl):\n",
    "        if exp in d:\n",
    "            if int(d[-1])==int(pn):\n",
    "                idxs.append(j)\n",
    "                \n",
    "    return np.array(idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_exp_decoding(X,y,dl):\n",
    "    score_folds = []\n",
    "    exp = {'243': [0,1,2], '384':[0,1,2,3]}\n",
    "    for e, pn in exp.items():\n",
    "        for p in pn:\n",
    "            test_idxs = return_idxs(dl, e, p)\n",
    "            train_idxs = np.setdiff1d(np.arange(627), test_idxs)\n",
    "            X_train = X[train_idxs].squeeze()\n",
    "            X_test = X[test_idxs].squeeze()\n",
    "            y_train = y[train_idxs].squeeze()\n",
    "            y_test = y[test_idxs].squeeze()\n",
    "            clf = LogisticRegressionCV(cv=5, random_state=0, max_iter=1000).fit(X_train, y_train)\n",
    "            score = clf.score(X_test, y_test)\n",
    "            score_folds.append(score)\n",
    "            \n",
    "    return score_folds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_BIL = evaluate_exp_decoding(BIL, y, dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_static = evaluate_exp_decoding(static, y, dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_BIL_trained = evaluate_exp_decoding(BIL_trained, y, dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(scores_BIL, scores_static)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(scores_BIL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(scores_BIL_trained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = np.hstack((np.repeat('Trained', 7), np.repeat('Untrained', 7)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "exp_decode = pd.DataFrame({'Accuracy': np.hstack((scores_BIL_trained, scores_BIL)), 'Model':model_names})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "sns.barplot(data=exp_decode, y='Accuracy', x='Model')\n",
    "sns.despine()\n",
    "plt.ylim(0.5, 1)\n",
    "plt.ylabel(\"Accuracy\", fontsize=30)\n",
    "plt.yticks([0.5, 1], fontsize=30)\n",
    "plt.xticks(fontsize=30)\n",
    "plt.xlabel('')\n",
    "plt.savefig('/home3/name/what-is-brainscore/figures/pereira/exp-decode.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
