{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "base = '/home3/name/what-is-brainscore/'\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import sys\n",
    "sys.path.append('/home3/name/what-is-brainscore/')\n",
    "from helper_funcs import *\n",
    "from plotting_funcs import *\n",
    "from scipy.stats import pearsonr\n",
    "base = '/home3/name/what-is-brainscore/results_all/'\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import matplotlib\n",
    "from scipy.stats import ttest_rel, ttest_1samp\n",
    "import statsmodels.formula.api as smf\n",
    "import nibabel as nib\n",
    "from nilearn import plotting\n",
    "from nilearn import surface\n",
    "from nilearn import datasets\n",
    "import plotly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resultsFolder_fed = f'{base}results_fedorenko/'\n",
    "figurePath = '/home3/name/what-is-brainscore/figures/fed/'\n",
    "dataset = 'fedorenko'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "find_best_layer('gpt2-large', resultsFolder_fed, required_str=['layer'], exclude_str=['hfgpt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "find_best_layer('roberta-large', resultsFolder_fed, required_str=['layer'], exclude_str=['hfgpt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_test_perf_across_layers(model_arr, dataset, layers_range, layer_name_arr, best_layer, saveName, \n",
    "                                 figurePath, resultsFolder, c, yticks):\n",
    "    \n",
    "    '''\n",
    "    :param dict model_arr: list containing model names \n",
    "    :param str dataset: which dataset to load data from \n",
    "    :param int layers_range: number of layers in model \n",
    "    :param str title: plot title \n",
    "    :param str layer_name_arr: name of layer to load \n",
    "    :param list best_layer: best layer of each model in model_arr\n",
    "    :param str saveName: where to save model \n",
    "    :param str figurePath: where to save figures\n",
    "    :param str resultsFolder: where to retrieve results from\n",
    "    :param list c: colors for each model \n",
    "    :param list yticks: yticks to plot\n",
    "    '''\n",
    "\n",
    "    counter = 0\n",
    "    \n",
    "    plt.figure(figsize=(14,8))\n",
    "    \n",
    "    for model, layer_range, layer_name, bl in zip(model_arr, layers_range, layer_name_arr, best_layer):\n",
    "        r2_layer = []\n",
    "        results = np.load(f\"{resultsFolder}{dataset}_{model}-static_layer1_1.npz\")\n",
    "        r2_emb_pos_m = results[\"out_of_sample_r2\"].mean()\n",
    "\n",
    "        for i in range(layer_range[0], layer_range[1]+1):\n",
    "            results = np.load(f\"{resultsFolder}{dataset}_{model}_{layer_name}{i}_1.npz\")\n",
    "            r2_layer.append(results[\"out_of_sample_r2\"].mean())\n",
    "        \n",
    "        if r2_emb_pos_m is not None:\n",
    "            plt.axhline(r2_emb_pos_m.mean(), linestyle='--', color=c[counter], label=model)\n",
    "        plt.plot(r2_layer, marker='o', color=c[counter])\n",
    "        plt.axvline(bl, color=c[counter], linestyle='--')\n",
    "        counter += 1\n",
    "        \n",
    "    plt.xlabel(\"Layer number\", fontsize=40)\n",
    "    plt.ylabel('R2' + r\"$_{oos}$\", fontsize=40)\n",
    "    plt.xticks(fontsize=30) \n",
    "    plt.yticks(yticks, fontsize=30) \n",
    "    plt.legend()\n",
    "    plt.legend(fontsize=20)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_test_perf_across_layers(['gpt2-large', 'roberta-large'], \n",
    "                             dataset='fedorenko', layers_range=[[0,36], [0,24]],\n",
    "                             layer_name_arr=['layer_', 'layer_'], best_layer=[19, 12], \n",
    "                             saveName='perf-across-layers-trained-fed', \n",
    "                             figurePath=figurePath, resultsFolder=resultsFolder_fed, \n",
    "                             c=['orange', 'blue'], yticks=[0, 0.015, 0.03, 0.045, 0.06])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_pos = np.load(f\"{resultsFolder_fed}{dataset}_positional_layer1_1_both.npz\")\n",
    "R2_pos = compute_R2_simple(results_pos, resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "results_static_grow = np.load(f\"{resultsFolder_fed}{dataset}_static_grow_layer1_1000_both.npz\")\n",
    "R2_static_grow = compute_R2_simple(results_static_grow, resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "\n",
    "results_pos_grow = np.load(f\"{resultsFolder_fed}{dataset}_pos+grow_layer1_1000_both.npz\")\n",
    "R2_pos_grow = compute_R2_simple(results_pos_grow, resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "results_static = np.load(f\"{resultsFolder_fed}{dataset}_gpt2-large-static_layer1_1_both.npz\")\n",
    "R2_static = compute_R2_simple(results_static, resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "results_static_pos = np.load(f\"{resultsFolder_fed}{dataset}_gpt2-large-static-pos_layer1_1_both.npz\")\n",
    "R2_static_pos = compute_R2_simple(results_static_pos, resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "results_gpt_pos_avg = np.load(f\"{resultsFolder_fed}{dataset}_gpt-pos-avg_layer1_1_both.npz\")\n",
    "R2_gpt_pos = compute_R2_simple(results_gpt_pos_avg , resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "results_all = np.load(f\"{resultsFolder_fed}{dataset}_static_bil_grow_layer1_1000_both.npz\")\n",
    "R2_all = compute_R2_simple(results_all , resultsFolder_fed, dataset=dataset, exp='both')\n",
    "\n",
    "find_best_layer('gpt2-large', resultsFolder_fed, required_str=['layer'], exclude_str=['hfgpt'])\n",
    "results_bil_gl = np.load(f\"{resultsFolder_fed}{dataset}_gpt2-large_layer_19_1_both.npz\")\n",
    "R2_bil_gl = compute_R2_simple(results_bil_gl, resultsFolder_fed, dataset=dataset, exp='both')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(R2_pos_grow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_interp = np.mean(results_static_grow['mse_stored'],axis=1)\n",
    "mse_bil = np.mean(results_bil_gl['mse_stored'], axis=1)\n",
    "bar_width = 0.35\n",
    "plt.bar(np.arange(13)-bar_width, mse_interp, width=bar_width, label='Static + Pos Grow')\n",
    "plt.bar(np.arange(13), mse_bil, width=bar_width, label='BIL')\n",
    "plt.legend()\n",
    "plt.ylabel('MSE', fontsize=18)\n",
    "plt.xlabel('Sentence group', fontsize=18)\n",
    "plt.ylim([0.2, 0.7])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_grow = np.load(f\"{resultsFolder_fed}{dataset}_bil_pos_grow_layer1_1000_both.npz\")\n",
    "R2_grow = compute_R2_simple(results_grow, resultsFolder_fed, dataset=dataset, exp='both')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(R2_grow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(R2_static_grow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(results_static_grow['mse_stored'],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(R2_static_grow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(R2_bil_gl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_true = np.load('/home3/name/what-is-brainscore/temp_data_all/y_federonko.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for a given electrode, each row represents the words\n",
    "# for a given list \n",
    "y_true_reshaped = np.reshape(y_true, (52, 8, 97))\n",
    "y_true_reshaped_avg_sent = np.mean(y_true_reshaped, axis=0)\n",
    "plt.figure(figsize=(15,10))\n",
    "plt.plot(y_true_reshaped_avg_sent, marker='o', alpha=0.4)\n",
    "plt.show()\n",
    "\n",
    "#plt.plot(np.mean(y_true_reshaped_avg_sent, axis=1), marker='o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(52):\n",
    "    plt.plot(y_true[i*8:(i+1)*8, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hat_bil = dict(results_bil_gl)['y_hat_folds']\n",
    "y_hat_pos = dict(results_pos)['y_hat_folds']\n",
    "y_static = dict(results_static)['y_hat_folds']\n",
    "mse_bil = mean_squared_error(y_true.T, y_hat_bil.T, multioutput='raw_values')\n",
    "mse_pos = mean_squared_error(y_true.T, y_hat_pos.T, multioutput='raw_values')\n",
    "mse_static = mean_squared_error(y_true.T, y_static.T, multioutput='raw_values')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_words_per_sent = 8\n",
    "mse_bil_per_position = []\n",
    "mse_pos_per_position = []\n",
    "mse_static_per_position = []\n",
    "for i in range(num_words_per_sent):\n",
    "    mse_bil_pos_i = mse_bil[i::num_words_per_sent]\n",
    "    mse_pos_pos_i = mse_pos[i::num_words_per_sent]\n",
    "    mse_static_pos_i = mse_static[i::num_words_per_sent]\n",
    "    mse_bil_per_position.append(np.mean(mse_bil_pos_i))\n",
    "    mse_pos_per_position.append(np.mean(mse_pos_pos_i))\n",
    "    mse_static_per_position.append(np.mean(mse_static_pos_i))\n",
    "    \n",
    "# Width of the bars\n",
    "bar_width = 0.20\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 10))\n",
    "\n",
    "# Plotting the bars\n",
    "ax.bar(np.arange(8) - bar_width, mse_bil_per_position, bar_width, label='MSE BIL')\n",
    "ax.bar(np.arange(8), mse_pos_per_position, bar_width, label='MSE POSITIONAL', color='r')\n",
    "ax.bar(np.arange(8) + bar_width, mse_static_per_position, bar_width, label='MSE STATIC', color='g')\n",
    "ax.set_xlabel('Position', fontsize=18)\n",
    "ax.set_ylabel(\"Average MSE\", fontsize=18)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_words_per_sent = 8\n",
    "mse_bil_per_position = []\n",
    "mse_pos_per_position = []\n",
    "mse_static_per_position = []\n",
    "for i in range(num_words_per_sent):\n",
    "    mse_bil_pos_i = mse_bil[i::num_words_per_sent]\n",
    "    mse_pos_pos_i = mse_pos[i::num_words_per_sent]\n",
    "    mse_static_pos_i = mse_static[i::num_words_per_sent]\n",
    "    mse_bil_per_position.append(np.mean(mse_bil_pos_i))\n",
    "    mse_pos_per_position.append(np.mean(mse_pos_pos_i))\n",
    "    mse_static_per_position.append(np.mean(mse_static_pos_i))\n",
    "    \n",
    "# Width of the bars\n",
    "bar_width = 0.20\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 10))\n",
    "\n",
    "# Plotting the bars\n",
    "ax.bar(np.arange(8) - bar_width, mse_bil_per_position, bar_width, label='MSE BIL')\n",
    "ax.bar(np.arange(8), mse_pos_per_position, bar_width, label='MSE POSITIONAL', color='r')\n",
    "ax.bar(np.arange(8) + bar_width, mse_static_per_position, bar_width, label='MSE STATIC', color='g')\n",
    "ax.set_xlabel('Position', fontsize=18)\n",
    "ax.set_ylabel(\"Average MSE\", fontsize=18)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bil = dict(np.load('/home3/name/what-is-brainscore/temp_data_all/temp_data_federonko/X_gpt2-large.npz'))['layer_10']\n",
    "X_static = dict(np.load('/home3/name/what-is-brainscore/temp_data_all/temp_data_federonko/X_gpt2-large-static.npz'))['layer1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bil_norm = np.linalg.norm(X_static, axis=1)\n",
    "X_bil_norm_reshaped = np.reshape(X_bil_norm, (52, 8))\n",
    "plt.plot(np.mean(X_bil_norm_reshaped,axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bil_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
