{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited local modules (e.g. plotting_functions, stats_funcs) before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Root of the what-is-brainscore checkout. Paths below are built as f'{base}...',\n",
    "# so it must end with '/'.\n",
    "base = '/home3/name/what-is-brainscore/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# Make the project-local helper modules (assumed to live directly in `base`) importable.\n",
    "sys.path.append(base)\n",
    "from plotting_functions import (plot_test_perf_across_layers, plot_across_subjects, save_fMRI_simple,\n",
    "                                pool_across_seeds, single_seed_mse_r2, find_rows_without_nan,\n",
    "                                pass_info_plot_hist2d)\n",
    "from stats_funcs import (compute_p_val, arrange_pvals_pd, max_across_nested, remove_neg_r2,\n",
    "                         modified_r2_and_idxs, mse_max_model)\n",
    "\n",
    "from scipy.stats import pearsonr, ttest_rel, ttest_1samp, wilcoxon\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import matplotlib\n",
    "import nibabel as nib\n",
    "from nilearn import plotting, surface, datasets\n",
    "import plotly\n",
    "import brainio\n",
    "from netCDF4 import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessed Pereira inputs (the subjects_*.npy and X_*.npz files read below).\n",
    "data_processed_folder = base + 'data_processed/pereira/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10  # number of random seeds (feature files ..._m0 ... ..._m{N-1} per experiment)\n",
    "model_name = 'gpt2-xl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input/output locations, all derived from `base` so the notebook can be relocated.\n",
    "figurePath = f'{base}analyze_results/figures_code/figures/pereira_untrained/'\n",
    "os.makedirs(figurePath, exist_ok=True)  # figures are saved here; don't fail on a fresh checkout\n",
    "resultsFolder = f'{base}results_all/results_pereira/untrained/'\n",
    "nc_file_pereira = f'{base}pereira_data/no_share/pereira_all.nc'\n",
    "# NOTE(review): pereira_data is not used by any later cell of this notebook.\n",
    "pereira_data = brainio.assemblies.DataAssembly.from_files(nc_file_pereira)\n",
    "dataset = 'pereira'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-experiment inputs keyed by experiment name ('both', '243', '384').\n",
    "exp = ['both', '243', '384']\n",
    "results_pereira_folder = f'{base}results_all/results_pereira/'\n",
    "br_labels_dict = {}      # network label of each voxel\n",
    "num_vox_dict = {}        # number of voxels\n",
    "ytest_dict = {}          # test targets (y_test_ordered_*.npy)\n",
    "mse_intercept_dict = {}\n",
    "subjects_dict = {}       # subject id of each voxel\n",
    "for e in exp:\n",
    "    bre = np.load(f'{base}pereira_data/networks_{e}.npy', allow_pickle=True)\n",
    "    br_labels_dict[e] = bre\n",
    "    num_vox_dict[e] = bre.shape[0]\n",
    "    mse_intercept_dict[e] = np.load(f'{results_pereira_folder}mse_intercept_{e}.npy')\n",
    "    ytest_dict[e] = np.load(f'{results_pereira_folder}y_test_ordered_{e}.npy')\n",
    "    subjects_dict[e] = np.load(f'{data_processed_folder}subjects_{e}.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_best_layer(exp, model_name, N, data_folder=None):\n",
    "    \"\"\"Read the best-layer names stored as keys of each seed's untrained-feature file.\n",
    "\n",
    "    For seed i in range(N) the file is X_{model_name}-untrained-sp-{exp}_m{i}.npz in\n",
    "    `data_folder` (default: `data_processed_folder`). The first key containing 'all' and\n",
    "    the first key containing 'lang' are taken as the best layers (for all voxels and for\n",
    "    the language network, judging by the names); '_lang' is stripped from the latter.\n",
    "\n",
    "    Returns (best_layer, best_layer_lang): two lists with one key name per seed.\n",
    "    Raises ValueError if a file has no key containing 'all' or 'lang'.\n",
    "    \"\"\"\n",
    "    if data_folder is None:\n",
    "        data_folder = data_processed_folder\n",
    "    best_layer = []\n",
    "    best_layer_lang = []\n",
    "    for i in range(N):\n",
    "        path = f'{data_folder}X_{model_name}-untrained-sp-{exp}_m{i}.npz'\n",
    "        # Only the key names are needed: read the archive index instead of loading\n",
    "        # every array, and close the file handle afterwards.\n",
    "        with np.load(path) as npz:\n",
    "            keys = list(npz.files)\n",
    "        bil_all = next((k for k in keys if 'all' in k), None)\n",
    "        bil_lang = next((k for k in keys if 'lang' in k), None)\n",
    "        if bil_all is None or bil_lang is None:\n",
    "            raise ValueError(f'{path}: no key containing all/lang in {keys}')\n",
    "        best_layer.append(bil_all)\n",
    "        best_layer_lang.append(bil_lang.replace('_lang', ''))\n",
    "    return best_layer, best_layer_lang\n",
    "\n",
    "\n",
    "_, bl_lang_243 = load_best_layer('243', model_name, N)\n",
    "_, bl_lang_384 = load_best_layer('384', model_name, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-voxel test MSE and r2 of every model, pooled across the N untrained seeds.\n",
    "# Feature sets: W = 'word-num' (SL), P = positional (SP), PW = both (SP+SL);\n",
    "# 'bil' = untrained GPT2-XL at each seed's best language layer (GPT2-XLU).\n",
    "def _pool_seeds(e, model, layers, niters, **kwargs):\n",
    "    \"\"\"Call pool_across_seeds for experiment `e`, repeating model/exp/niters once per seed.\"\"\"\n",
    "    return pool_across_seeds(ytest_dict[e], np.repeat(model, N), np.repeat(e, N), layers,\n",
    "                             niters=np.repeat(niters, N), resultsFolder=resultsFolder, **kwargs)\n",
    "\n",
    "\n",
    "# GPT2-XLU alone\n",
    "mse_bil_243, r2_bil_243 = _pool_seeds('243', 'gpt2-xl-untrained-sp', bl_lang_243, 1)\n",
    "mse_bil_384, r2_bil_384 = _pool_seeds('384', 'gpt2-xl-untrained-sp', bl_lang_384, 1)\n",
    "\n",
    "# GPT2-XLU + SP + SL\n",
    "mse_bil_PW_243, r2_bil_PW_243 = _pool_seeds('243', 'gpt2-xl-ut_bil-lang_POSWN', np.repeat('layer1', N), 1000,\n",
    "                                             seed_last=False)\n",
    "mse_bil_PW_384, r2_bil_PW_384 = _pool_seeds('384', 'gpt2-xl-ut_bil-lang_POSWN', np.repeat('layer1', N), 1000,\n",
    "                                             seed_last=False)\n",
    "\n",
    "# GPT2-XLU + SL\n",
    "mse_bil_W_243, r2_bil_W_243 = _pool_seeds('243', 'gpt2-xl-ut_bil-lang_WN', np.repeat('layer1', N), 1000,\n",
    "                                           seed_last=False)\n",
    "mse_bil_W_384, r2_bil_W_384 = _pool_seeds('384', 'gpt2-xl-ut_bil-lang_WN', np.repeat('layer1', N), 1000,\n",
    "                                           seed_last=False)\n",
    "\n",
    "# GPT2-XLU + SP\n",
    "mse_bil_P_243, r2_bil_P_243 = _pool_seeds('243', 'gpt2-xl-ut_bil-lang_POS', np.repeat('layer1', N), 1000,\n",
    "                                           seed_last=False)\n",
    "mse_bil_P_384, r2_bil_P_384 = _pool_seeds('384', 'gpt2-xl-ut_bil-lang_POS', np.repeat('layer1', N), 1000,\n",
    "                                           seed_last=False)\n",
    "\n",
    "# Feature-only models (no LLM): SP+SL, SP, SL\n",
    "mse_PW_243, r2_PW_243 = single_seed_mse_r2(ytest_dict['243'], 'positional_WN', niters=1000, exp='243', layer_name='layer1', resultsFolder=resultsFolder)\n",
    "mse_PW_384, r2_PW_384 = single_seed_mse_r2(ytest_dict['384'], 'positional_WN', niters=1000, exp='384', layer_name='layer1', resultsFolder=resultsFolder)\n",
    "\n",
    "mse_P_243, r2_P_243 = single_seed_mse_r2(ytest_dict['243'], 'positional_simple', niters=1, exp='243', layer_name='layer1', resultsFolder=resultsFolder)\n",
    "mse_P_384, r2_P_384 = single_seed_mse_r2(ytest_dict['384'], 'positional_simple', niters=1, exp='384', layer_name='layer1', resultsFolder=resultsFolder)\n",
    "\n",
    "mse_W_243, r2_W_243 = single_seed_mse_r2(ytest_dict['243'], 'word-num', niters=1, exp='243', layer_name='layer1', resultsFolder=resultsFolder)\n",
    "mse_W_384, r2_W_384 = single_seed_mse_r2(ytest_dict['384'], 'word-num', niters=1, exp='384', layer_name='layer1', resultsFolder=resultsFolder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxels with no NaN r2 in any of the 7 models (as returned by find_rows_without_nan).\n",
    "non_nan_243 = find_rows_without_nan(np.vstack([\n",
    "    r2_W_243, r2_P_243, r2_PW_243,\n",
    "    r2_bil_243, r2_bil_W_243, r2_bil_P_243, r2_bil_PW_243,\n",
    "]))\n",
    "non_nan_384 = find_rows_without_nan(np.vstack([\n",
    "    r2_W_384, r2_P_384, r2_PW_384,\n",
    "    r2_bil_384, r2_bil_W_384, r2_bil_P_384, r2_bil_PW_384,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test performance of every layer of the N untrained GPT2-XL seeds, per network (384).\n",
    "# N (number of seeds) comes from the configuration cell; it is not redefined here.\n",
    "networks = ['language', 'DMN', 'MD', 'visual']\n",
    "colors = dict(zip(networks, ['tab:blue', 'tab:green', 'tab:red', 'tab:purple']))\n",
    "model_names = [f'{model_name}-untrained-sp' for _ in range(N)]\n",
    "layers_range = [[0, 48] for _ in range(N)]\n",
    "layer_name_arr = ['layer_' for _ in range(48)]\n",
    "model_nums = np.arange(N)\n",
    "plot_test_perf_across_layers(model_names, dataset='pereira', layers_range=layers_range,\n",
    "                             layer_name_arr=layer_name_arr, saveName='gpt2l-ut_perf_across_layers_384',\n",
    "                             figurePath=figurePath, resultsFolder=resultsFolder, yticks=[0, .017], exp='384',\n",
    "                             model_nums=model_nums,\n",
    "                             networks=networks, br_labels=br_labels_dict['384'],\n",
    "                             subjects=subjects_dict['384'], colors=colors, plot_legend=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same layer-wise plot for the 243 experiment (the legend is drawn only in the 384 figure).\n",
    "plot_test_perf_across_layers(\n",
    "    model_names,\n",
    "    dataset='pereira',\n",
    "    layers_range=layers_range,\n",
    "    layer_name_arr=layer_name_arr,\n",
    "    saveName='gpt2l-ut_perf_across_layers_243',\n",
    "    figurePath=figurePath,\n",
    "    resultsFolder=resultsFolder,\n",
    "    yticks=[0, 0.015, 0.03],\n",
    "    exp='243',\n",
    "    model_nums=model_nums,\n",
    "    networks=networks,\n",
    "    br_labels=br_labels_dict['243'],\n",
    "    subjects=subjects_dict['243'],\n",
    "    colors=colors,\n",
    "    plot_legend=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format per-voxel r2 table (one row per model x voxel) for each experiment.\n",
    "model_names = ['SL', 'SP', 'SP+SL', 'GPT2-XLU', 'GPT2-XLU+SL', 'GPT2-XLU+SP', 'GPT2-XLU+SP+SL']\n",
    "best_DEM_model = model_names[2]   # 'SP+SL'\n",
    "best_LLM_model = model_names[-1]  # 'GPT2-XLU+SP+SL'\n",
    "LLM_name = model_names[3]         # 'GPT2-XLU'\n",
    "num_models = len(model_names)\n",
    "network_order = ['language', 'DMN', 'MD', 'visual']\n",
    "\n",
    "\n",
    "def stack_r2(e, r2_per_model):\n",
    "    \"\"\"Return a long DataFrame (Model, r2, voxel_id, Network, subjects) for experiment `e`.\n",
    "\n",
    "    `r2_per_model` holds one per-voxel r2 array per entry of `model_names`, in that order.\n",
    "    Rows containing NaN are dropped.\n",
    "    \"\"\"\n",
    "    if len(r2_per_model) != num_models:\n",
    "        raise ValueError(f'expected {num_models} r2 arrays (one per model), got {len(r2_per_model)}')\n",
    "    n_vox = num_vox_dict[e]\n",
    "    return pd.DataFrame({\n",
    "        'Model': np.repeat(np.array(model_names), n_vox),\n",
    "        'r2': np.concatenate(r2_per_model),\n",
    "        'voxel_id': np.tile(np.arange(n_vox), num_models),\n",
    "        'Network': np.tile(br_labels_dict[e], num_models),\n",
    "        'subjects': np.tile(subjects_dict[e], num_models),\n",
    "    }).dropna()\n",
    "\n",
    "\n",
    "r2_stacked_pd_243 = stack_r2('243', [r2_W_243, r2_P_243, r2_PW_243, r2_bil_243,\n",
    "                                     r2_bil_W_243, r2_bil_P_243, r2_bil_PW_243])\n",
    "r2_stacked_pd_384 = stack_r2('384', [r2_W_384, r2_P_384, r2_PW_384, r2_bil_384,\n",
    "                                     r2_bil_W_384, r2_bil_P_384, r2_bil_PW_384])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Language-network r2 across subjects for all 7 models (243 experiment).\n",
    "subj_avg_pd_243, _, _ = plot_across_subjects(\n",
    "    r2_stacked_pd_243.copy(),\n",
    "    selected_networks=['language'],\n",
    "    figurePath=figurePath,\n",
    "    saveName='243_across_subjects_xl_untrained',\n",
    "    hue_order=model_names,\n",
    "    yticks=[0, 0.08],\n",
    "    order=['language'],\n",
    "    clip_zero=True,\n",
    "    plot_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Language-network rows of the per-subject table, for the two best models only (243).\n",
    "subj_avg_pd_243_ri = subj_avg_pd_243.reset_index().drop(columns=['voxel_id'])\n",
    "subj_243_lang = subj_avg_pd_243_ri[subj_avg_pd_243_ri['Network'] == 'language']\n",
    "filtered_df = subj_243_lang[subj_243_lang['Model'].isin((best_DEM_model, best_LLM_model))]\n",
    "filtered_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Language-network r2 across subjects for all 7 models (384 experiment).\n",
    "subj_avg_pd_384, _, _ = plot_across_subjects(\n",
    "    r2_stacked_pd_384.copy(),\n",
    "    selected_networks=['language'],\n",
    "    figurePath=figurePath,\n",
    "    saveName='384_across_subjects_xl_untrained',\n",
    "    hue_order=model_names,\n",
    "    yticks=[0, 0.05],\n",
    "    order=['language'],\n",
    "    clip_zero=True,\n",
    "    plot_legend=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Language-network rows of the per-subject table, for the two best models only (384).\n",
    "subj_avg_pd_384_ri = subj_avg_pd_384.reset_index().drop(columns=['voxel_id'])\n",
    "subj_384_lang = subj_avg_pd_384_ri[subj_avg_pd_384_ri['Network'] == 'language']\n",
    "filtered_df = subj_384_lang[subj_384_lang['Model'].isin((best_DEM_model, best_LLM_model))]\n",
    "filtered_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-voxel comparison (pass_info_plot_hist2d) of SP+SL against GPT2-XLU+SP+SL (384),\n",
    "# with the same max_val for every network.\n",
    "pass_info_plot_hist2d(\n",
    "    max_val_dict=dict.fromkeys(['language', 'DMN', 'MD', 'visual'], 0.15),\n",
    "    min_val=-0.03,\n",
    "    df=r2_stacked_pd_384,\n",
    "    best_DEM_model=best_DEM_model,\n",
    "    best_LLM_model=best_LLM_model,\n",
    "    figurePath=figurePath,\n",
    "    saveName='scatter_384_xl_ut',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-voxel comparison (pass_info_plot_hist2d) of SP+SL against GPT2-XLU+SP+SL (243),\n",
    "# with the same max_val for every network.\n",
    "pass_info_plot_hist2d(\n",
    "    max_val_dict=dict.fromkeys(['language', 'DMN', 'MD', 'visual'], 0.2),\n",
    "    min_val=-0.04,\n",
    "    df=r2_stacked_pd_243,\n",
    "    best_DEM_model=best_DEM_model,\n",
    "    best_LLM_model=best_LLM_model,\n",
    "    figurePath=figurePath,\n",
    "    saveName='scatter_243_xl_ut',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel-wise p-values for GPT2-XLU+SP+SL vs SP+SL (raw MSEs), restricted to non-NaN voxels,\n",
    "# then the percent of voxels with p < 0.05 before and after FDR adjustment (all networks).\n",
    "pvals_243_BLPW_PW = compute_p_val('243', num_vox_dict, mse_bil_PW_243, mse_PW_243)[non_nan_243]\n",
    "pvals_384_BLPW_PW = compute_p_val('384', num_vox_dict, mse_bil_PW_384, mse_PW_384)[non_nan_384]\n",
    "pvals_pd_243 = arrange_pvals_pd(pvals_243_BLPW_PW, '243', subjects_dict, br_labels_dict, non_nan_243)\n",
    "pvals_pd_384 = arrange_pvals_pd(pvals_384_BLPW_PW, '384', subjects_dict, br_labels_dict, non_nan_384)\n",
    "psig_384_before_fdr = np.count_nonzero(pvals_pd_384.pvals < 0.05) / len(pvals_pd_384) * 100\n",
    "psig_384_after_fdr = np.count_nonzero(pvals_pd_384.pvals_adj < 0.05) / len(pvals_pd_384) * 100\n",
    "psig_243_before_fdr = np.count_nonzero(pvals_pd_243.pvals < 0.05) / len(pvals_pd_243) * 100\n",
    "psig_243_after_fdr = np.count_nonzero(pvals_pd_243.pvals_adj < 0.05) / len(pvals_pd_243) * 100\n",
    "print('384', psig_384_before_fdr, psig_384_after_fdr)\n",
    "print('243', psig_243_before_fdr, psig_243_after_fdr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same raw-MSE comparison as the previous cell, restricted to language-network voxels\n",
    "# (reuses pvals_pd_243 / pvals_pd_384 from the previous cell). The max-over-sub-models\n",
    "# version of this summary is computed further down, once mse_*_concat is defined.\n",
    "pvals_pd_243_lang = pvals_pd_243.loc[pvals_pd_243.network == 'language']\n",
    "pvals_pd_384_lang = pvals_pd_384.loc[pvals_pd_384.network == 'language']\n",
    "psig_384_before_fdr = np.count_nonzero(pvals_pd_384_lang.pvals < 0.05) / len(pvals_pd_384_lang) * 100\n",
    "psig_384_after_fdr = np.count_nonzero(pvals_pd_384_lang.pvals_adj < 0.05) / len(pvals_pd_384_lang) * 100\n",
    "psig_243_before_fdr = np.count_nonzero(pvals_pd_243_lang.pvals < 0.05) / len(pvals_pd_243_lang) * 100\n",
    "psig_243_after_fdr = np.count_nonzero(pvals_pd_243_lang.pvals_adj < 0.05) / len(pvals_pd_243_lang) * 100\n",
    "print('384', psig_384_before_fdr, psig_384_after_fdr)\n",
    "print('243', psig_243_before_fdr, psig_243_after_fdr)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Voxel-wise corrections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-voxel corrected r2 tables for the 384 experiment (three frames from modified_r2_and_idxs).\n",
    "(BL_mod_384,\n",
    " non_BL_mod_384,\n",
    " full_mod_384) = modified_r2_and_idxs(r2_stacked_pd_384, LLM_name, best_DEM_model, best_LLM_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-voxel corrected r2 tables for the 243 experiment (three frames from modified_r2_and_idxs).\n",
    "(BL_mod_243,\n",
    " non_BL_mod_243,\n",
    " full_mod_243) = modified_r2_and_idxs(r2_stacked_pd_243, LLM_name, best_DEM_model, best_LLM_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine each model family with mse_max_model over non-NaN voxels\n",
    "# (NOTE(review): presumably keeps the best sub-model per voxel - confirm in stats_funcs).\n",
    "# Feature-only family: SL, SP, SP+SL.  GPT2-XLU family: GPT2-XLU, +SL, +SP, +SP+SL.\n",
    "mse_PW_243_concat = mse_max_model(np.stack([mse_W_243, mse_P_243, mse_PW_243])[:, :, non_nan_243])\n",
    "mse_bil_PW_243_concat = mse_max_model(\n",
    "    np.stack([mse_bil_243, mse_bil_W_243, mse_bil_P_243, mse_bil_PW_243])[:, :, non_nan_243])\n",
    "\n",
    "mse_PW_384_concat = mse_max_model(np.stack([mse_W_384, mse_P_384, mse_PW_384])[:, :, non_nan_384])\n",
    "mse_bil_PW_384_concat = mse_max_model(\n",
    "    np.stack([mse_bil_384, mse_bil_W_384, mse_bil_P_384, mse_bil_PW_384])[:, :, non_nan_384])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: shape of the transposed array passed to compute_p_val below.\n",
    "mse_bil_PW_384_concat.shape[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voxel-wise p-values for the GPT2-XLU family vs the feature-only family (combined MSEs),\n",
    "# then the percent of voxels with p < 0.05 before and after FDR adjustment (all networks).\n",
    "pvals_243_BLPW_PW = compute_p_val('243', num_vox_dict, mse_bil_PW_243_concat.T, mse_PW_243_concat.T)\n",
    "pvals_384_BLPW_PW = compute_p_val('384', num_vox_dict, mse_bil_PW_384_concat.T, mse_PW_384_concat.T)\n",
    "pvals_pd_243 = arrange_pvals_pd(pvals_243_BLPW_PW, '243', subjects_dict, br_labels_dict, non_nan_243)\n",
    "pvals_pd_384 = arrange_pvals_pd(pvals_384_BLPW_PW, '384', subjects_dict, br_labels_dict, non_nan_384)\n",
    "psig_384_before_fdr = np.count_nonzero(pvals_pd_384.pvals < 0.05) / len(pvals_pd_384) * 100\n",
    "psig_384_after_fdr = np.count_nonzero(pvals_pd_384.pvals_adj < 0.05) / len(pvals_pd_384) * 100\n",
    "psig_243_before_fdr = np.count_nonzero(pvals_pd_243.pvals < 0.05) / len(pvals_pd_243) * 100\n",
    "psig_243_after_fdr = np.count_nonzero(pvals_pd_243.pvals_adj < 0.05) / len(pvals_pd_243) * 100\n",
    "print('384', psig_384_before_fdr, psig_384_after_fdr)\n",
    "print('243', psig_243_before_fdr, psig_243_after_fdr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First rows of the 243 p-value table.\n",
    "pvals_pd_243.iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same combined-MSE comparison, restricted to language-network voxels.\n",
    "pvals_243_BLPW_PW = compute_p_val('243', num_vox_dict, mse_bil_PW_243_concat.T, mse_PW_243_concat.T)\n",
    "pvals_384_BLPW_PW = compute_p_val('384', num_vox_dict, mse_bil_PW_384_concat.T, mse_PW_384_concat.T)\n",
    "pvals_pd_243 = arrange_pvals_pd(pvals_243_BLPW_PW, '243', subjects_dict, br_labels_dict,\n",
    "                                non_nan_243).loc[lambda d: d.network == 'language']\n",
    "pvals_pd_384 = arrange_pvals_pd(pvals_384_BLPW_PW, '384', subjects_dict, br_labels_dict,\n",
    "                                non_nan_384).loc[lambda d: d.network == 'language']\n",
    "psig_384_before_fdr = np.count_nonzero(pvals_pd_384.pvals < 0.05) / len(pvals_pd_384) * 100\n",
    "psig_384_after_fdr = np.count_nonzero(pvals_pd_384.pvals_adj < 0.05) / len(pvals_pd_384) * 100\n",
    "psig_243_before_fdr = np.count_nonzero(pvals_pd_243.pvals < 0.05) / len(pvals_pd_243) * 100\n",
    "psig_243_after_fdr = np.count_nonzero(pvals_pd_243.pvals_adj < 0.05) / len(pvals_pd_243) * 100\n",
    "print('384', psig_384_before_fdr, psig_384_after_fdr)\n",
    "print('243', psig_243_before_fdr, psig_243_after_fdr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One corrected-r2 frame per experiment: non_BL_mod_* rows followed by full_mod_* rows.\n",
    "modified_243 = pd.concat([non_BL_mod_243, full_mod_243])\n",
    "modified_384 = pd.concat([non_BL_mod_384, full_mod_384])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seaborn 'deep' palette; entries 2 and 6 colour the two models in the next plots.\n",
    "default_palette = sns.color_palette(palette='deep')\n",
    "default_palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corrected language-network r2 across subjects: SP+SL vs GPT2-XLU+SP+SL (384).\n",
    "_, _, _ = plot_across_subjects(\n",
    "    modified_384.copy(),\n",
    "    figurePath=figurePath,\n",
    "    selected_networks=['language'],\n",
    "    saveName='384_across_subjects_xl_untrained_mod',\n",
    "    hue_order=[best_DEM_model, best_LLM_model],\n",
    "    yticks=[0, 0.05],\n",
    "    order=['language'],\n",
    "    clip_zero=True,\n",
    "    color_palette=[default_palette[i] for i in (2, 6)],\n",
    "    draw_lines=True,\n",
    "    ms=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corrected language-network r2 across subjects: SP+SL vs GPT2-XLU+SP+SL (243).\n",
    "_, _, _ = plot_across_subjects(\n",
    "    modified_243.copy(),\n",
    "    figurePath=figurePath,\n",
    "    selected_networks=['language'],\n",
    "    saveName='243_across_subjects_xl_untrained_mod',\n",
    "    hue_order=[best_DEM_model, best_LLM_model],\n",
    "    yticks=[0, 0.08],\n",
    "    order=['language'],\n",
    "    clip_zero=True,\n",
    "    color_palette=[default_palette[i] for i in (2, 6)],\n",
    "    draw_lines=True,\n",
    "    ms=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled per-voxel comparison plot of the corrected (modified_384) r2 values.\n",
    "# It passes no figurePath/saveName; re-enable by uncommenting, or delete this cell.\n",
    "#pass_info_plot_hist2d(max_val_dict={'language': 0.12, 'DMN': 0.06, 'MD':0.10, 'visual': 0.08}, min_val=-0.03, df=modified_384, best_DEM_model=best_DEM_model, best_LLM_model=best_LLM_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled per-voxel comparison plot of the corrected (modified_243) r2 values.\n",
    "# It passes no figurePath/saveName; re-enable by uncommenting, or delete this cell.\n",
    "#pass_info_plot_hist2d(max_val_dict={'language': 0.20, 'DMN': 0.20, 'MD':0.20, 'visual': 0.20}, min_val=-0.03, df=modified_243, best_DEM_model=best_DEM_model, best_LLM_model=best_LLM_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corrected r2 of SP+SL and GPT2-XLU+SP+SL (243), clipped at 0 as for the 384 experiment:\n",
    "# plot_glass_brain draws absolute values by default, so negative r2 would look positive.\n",
    "PW_corrected_r2_243 = np.clip(modified_243.loc[modified_243.Model == best_DEM_model].r2.values, 0, np.inf)\n",
    "BLPW_corrected_r2_243 = np.clip(modified_243.loc[modified_243.Model == best_LLM_model].r2.values, 0, np.inf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the corrected r2 maps of both models (243) for the glass-brain plots below.\n",
    "# The r2 values come from NaN-dropped tables, so subject labels are restricted to the\n",
    "# non-NaN voxels as well (same as the 384 experiment); a no-op if 243 has no NaN voxels.\n",
    "plotting_folder = '/data/LLMs/Pereira/plotting_data/'  # .nii files are read back from here\n",
    "subjects_243_valid = subjects_dict['243'][non_nan_243].squeeze()\n",
    "for r2_values, save_name in [(PW_corrected_r2_243, 'SP+SL_243'),\n",
    "                             (BLPW_corrected_r2_243, 'GPT2-XLU+SP+SL_243')]:\n",
    "    subjects, stored_data_exp_WN_pos = save_fMRI_simple(\n",
    "        r2_values, exp='243',\n",
    "        subjects_to_plot=np.unique(subjects_243_valid),\n",
    "        subjects_all=subjects_243_valid, save_name=save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Glass-brain maps of the corrected r2 (243) for file suffix '426' (presumably a subject id).\n",
    "# Input names follow the save_name used when writing the maps: '<save_name>_<suffix>.nii'.\n",
    "subject_id_243 = '426'\n",
    "plotting.plot_glass_brain(f'{plotting_folder}SP+SL_243_{subject_id_243}.nii',\n",
    "                          colorbar=False, display_mode='xz', vmax=0.3,\n",
    "                          output_file=f'{figurePath}glass_brain_SP+SL_243_{subject_id_243}.pdf')\n",
    "plotting.plot_glass_brain(f'{plotting_folder}GPT2-XLU+SP+SL_243_{subject_id_243}.nii',\n",
    "                          colorbar=True, display_mode='xz', vmax=0.3,\n",
    "                          output_file=f'{figurePath}glass_brain_GPT2-XLU+SP+SL_243_{subject_id_243}.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corrected r2 of SP+SL and GPT2-XLU+SP+SL (384), with negative values clipped to 0.\n",
    "PW_corrected_r2_384, BLPW_corrected_r2_384 = (\n",
    "    np.clip(modified_384.loc[modified_384['Model'] == model, 'r2'].values, 0, np.inf)\n",
    "    for model in ('SP+SL', 'GPT2-XLU+SP+SL')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First rows of the corrected 384 table.\n",
    "modified_384.iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum corrected r2 per (subject, model) and overall (useful for choosing vmax below).\n",
    "# display() is needed: only a cell's last expression is shown automatically.\n",
    "display(modified_384.groupby(['subjects', 'Model']).r2.max())\n",
    "modified_384.r2.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the clipped corrected r2 maps of both models (384) for the glass-brain plots below.\n",
    "# Subject labels are restricted to the non-NaN voxels, and both calls use the same subjects.\n",
    "plotting_folder = '/data/LLMs/Pereira/plotting_data/'  # .nii files are read back from here\n",
    "subjects_384_valid = subjects_dict['384'][non_nan_384].squeeze()\n",
    "for r2_values, save_name in [(PW_corrected_r2_384, 'SP+SL_384'),\n",
    "                             (BLPW_corrected_r2_384, 'GPT2-XLU+SP+SL_384')]:\n",
    "    subjects, stored_data_exp_WN_pos = save_fMRI_simple(\n",
    "        r2_values, exp='384',\n",
    "        subjects_to_plot=np.unique(subjects_384_valid),\n",
    "        subjects_all=subjects_384_valid, save_name=save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Glass-brain maps of the corrected r2 (384) for file suffix '343' (presumably a subject id),\n",
    "# on a shared colour scale. Both maps are saved to figurePath.\n",
    "subject_id_384 = '343'\n",
    "vmax_384 = 0.20\n",
    "plotting.plot_glass_brain(f'{plotting_folder}SP+SL_384_{subject_id_384}.nii',\n",
    "                          colorbar=False, display_mode='xz', vmax=vmax_384,\n",
    "                          output_file=f'{figurePath}glass_brain_SP+SL_384_{subject_id_384}.pdf')\n",
    "plotting.plot_glass_brain(f'{plotting_folder}GPT2-XLU+SP+SL_384_{subject_id_384}.nii',\n",
    "                          colorbar=True, display_mode='xz', vmax=vmax_384,\n",
    "                          output_file=f'{figurePath}glass_brain_GPT2-XLU+SP+SL_384_{subject_id_384}.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): leftover save path for the GPT2-XLU+SP+SL 384 glass-brain figure; this cell\n",
    "# only binds `output_file` and does not pass it to any plotting call.\n",
    "output_file=f'{figurePath}glass_brain_GPT2-XLU+SP+SL_384_343.pdf'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
