{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base = '/home3/name/what-is-brainscore/'\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base = '/home3/name/what-is-brainscore/'\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "base = '/home3/name/what-is-brainscore/'\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import sys\n",
    "sys.path.append('/home3/name/what-is-brainscore/')\n",
    "from helper_funcs import find_best_layer\n",
    "from plotting_functions import plot_across_subjects, plot_test_perf_across_layers, plot_hist2d, save_fMRI_simple, compute_stats_and_var_exp, pass_info_plot_hist2d, find_rows_without_nan\n",
    "from stats_funcs import compute_p_val, arrange_pvals_pd, max_across_nested, remove_neg_r2, modified_r2_and_idxs, mse_max_model\n",
    "from scipy.stats import pearsonr\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import matplotlib\n",
    "from scipy.stats import ttest_rel, ttest_1samp\n",
    "import statsmodels.formula.api as smf\n",
    "import nibabel as nib\n",
    "from nilearn import plotting\n",
    "from nilearn import surface\n",
    "from nilearn import datasets\n",
    "import plotly\n",
    "import brainio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load important things\n",
    "base = '/home3/name/what-is-brainscore/'\n",
    "figurePath = '/home3/name/what-is-brainscore/analyze_results/figures_code/figures/pereira_trained/'\n",
    "resultsFolder = f'{base}results_all/results_pereira/'\n",
    "nc_file_pereira = '/home3/name/what-is-brainscore/pereira_data/no_share/Pereira_data.nc'\n",
    "pereira_data = brainio.assemblies.DataAssembly.from_files(nc_file_pereira)\n",
    "nii_file_path_base = \"/home3/name/neural-nlp-exact/neural_nlp/analyze/surface_projection/\"\n",
    "data_processed_folder = f'{base}data_processed/pereira/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp = ['both', '243', '384']\n",
    "br_labels_dict = {}\n",
    "num_vox_dict = {}\n",
    "ytest_dict = {}\n",
    "mse_intercept_dict = {}\n",
    "subjects_dict = {}\n",
    "for e in exp:\n",
    "    bre = np.load(f'{base}/pereira_data/networks_{e}.npy', allow_pickle=True)\n",
    "    br_labels_dict[e] = bre\n",
    "    num_vox_dict[e] = bre.shape[0]\n",
    "    mse_intercept_dict[e] = np.load(f'/home3/name/what-is-brainscore/results_all/results_pereira/mse_intercept_{e}.npy')\n",
    "    ytest_dict[e] = np.load(f'/home3/name/what-is-brainscore/results_all/results_pereira/y_test_ordered_{e}.npy')\n",
    "    subjects_dict[e] = np.load(f\"{data_processed_folder}/subjects_{e}.npy\", allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "across_layer_results_384 = np.load(\"/data/LLMs/Pereira/trained_results/GPT2SP_TEST_R2_ACROSS_LAYERS_384.npy\")\n",
    "across_layer_results_243 = np.load(\"/data/LLMs/Pereira/trained_results/GPT2SP_TEST_R2_ACROSS_LAYERS_243.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1\n",
    "yticks=[]\n",
    "networks = ['language', 'DMN', 'MD', 'visual']\n",
    "model_names = ['gpt2-xl-sp' for i in range(N)]\n",
    "layers_range = [[0,48] for i in range(N)]\n",
    "layer_name_arr = [f'layer_' for i in range(0,48)]\n",
    "model_nums = [None]\n",
    "colors = []\n",
    "colors = {}\n",
    "colors_arr = ['tab:blue', 'tab:green', 'tab:red', 'tab:purple']\n",
    "for i, n in enumerate(networks):\n",
    "    colors[n] = colors_arr[i]\n",
    "res_pd = plot_test_perf_across_layers(model_names, dataset='pereira', layers_range=layers_range,\n",
    "                             layer_name_arr=layer_name_arr, saveName='gpt2-xl_perf_across_layers_243', \n",
    "                             figurePath=figurePath, resultsFolder=resultsFolder, yticks=[0, 0.03, 0.06], exp='243', \n",
    "                             model_nums=model_nums,\n",
    "                             networks=networks, br_labels=br_labels_dict['243'], \n",
    "                             subjects=subjects_dict['243'], colors=colors, plot_legend=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1\n",
    "yticks=[]\n",
    "networks = ['language']\n",
    "model_names = ['gpt2-xl-sp' for i in range(N)]\n",
    "layers_range = [[0,48] for i in range(N)]\n",
    "layer_name_arr = [f'layer_' for i in range(0,48)]\n",
    "model_nums = [None]\n",
    "colors = {}\n",
    "colors_arr = ['tab:blue', 'tab:green', 'tab:red', 'tab:purple']\n",
    "for i, n in enumerate(networks):\n",
    "    colors[n] = colors_arr[i]\n",
    "res_pd = plot_test_perf_across_layers(model_names, dataset='pereira', layers_range=layers_range,\n",
    "                             layer_name_arr=layer_name_arr, saveName='gpt2-xl_perf_across_layers_384', \n",
    "                             figurePath=figurePath, resultsFolder=resultsFolder, yticks=[0, 0.03, 0.06], exp='384', \n",
    "                             model_nums=model_nums,\n",
    "                             networks=networks, br_labels=br_labels_dict['384'], \n",
    "                             subjects=subjects_dict['384'], colors=colors_arr[0], plot_legend=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_best(df, model_to_keep, models_to_discard):\n",
    "    \n",
    "    for md in models_to_discard:\n",
    "        df =  df.loc[~df.model_name.str.contains(md)]\n",
    "    if len(model_to_keep) > 0:\n",
    "        df = df.loc[df.model_name.str.contains(model_to_keep)]\n",
    "    \n",
    "    best_model = df.loc[df['r2_vals'].idxmax()].model_name\n",
    "    return best_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_models_results_384 = np.load(\"/data/LLMs/Pereira/trained_results/r2_384_trained.npy\")\n",
    "all_models_results_243 = np.load(\"/data/LLMs/Pereira/trained_results/r2_243_trained.npy\")\n",
    "model_names = np.load(\"/data/LLMs/Pereira/trained_results/r2_384_trained_model_names.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lang_voxels_384 = np.argwhere(br_labels_dict['384']=='language').squeeze()\n",
    "all_models_results_384[:, lang_voxels_384].shape\n",
    "model_r2_384 = pd.DataFrame({'model_name': model_names, 'r2_vals': np.nanmean(np.clip(all_models_results_384[:, lang_voxels_384], 0, np.inf), axis=1)})\n",
    "POS_SL_384 = find_best(model_r2_384, '', ['WORD', 'SENSE', 'SYNT', 'GPT2-XL'])\n",
    "word_384 = find_best(model_r2_384, 'WORD', ['SENSE', 'SYNT', 'GPT2-XL'])\n",
    "sense_384 = find_best(model_r2_384, 'SENSE', ['SYNT', 'GPT2-XL'])\n",
    "synt_384 = find_best(model_r2_384, 'SYNT', ['GPT2-XL'])\n",
    "GPT2_384 = find_best(model_r2_384, 'GPT2-XL', [])\n",
    "\n",
    "best_models_384 = ['GPT2-XL', POS_SL_384, word_384, sense_384, synt_384, GPT2_384]\n",
    "\n",
    "lang_voxels_243 = np.argwhere(br_labels_dict['243']=='language').squeeze()\n",
    "all_models_results_243[:, lang_voxels_243].shape\n",
    "model_r2_243 = pd.DataFrame({'model_name': model_names, 'r2_vals': np.nanmean(np.clip(all_models_results_243[:, lang_voxels_243], 0, np.inf), axis=1)})\n",
    "POS_SL_243 = find_best(model_r2_243, '', ['WORD', 'SENSE', 'SYNT', 'GPT2-XL'])\n",
    "word_243 = find_best(model_r2_243, 'WORD', ['SENSE', 'SYNT', 'GPT2-XL'])\n",
    "sense_243 = find_best(model_r2_243, 'SENSE', ['SYNT', 'GPT2-XL'])\n",
    "synt_243 = find_best(model_r2_243, 'SYNT', ['GPT2-XL'])\n",
    "GPT2_243 = find_best(model_r2_243, 'GPT2-XL', [])\n",
    "\n",
    "best_models_243 = ['GPT2-XL', POS_SL_243, word_243, sense_243, synt_243, GPT2_243]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_idxs_243 = [np.argwhere(model_names==x)[0][0] for x in best_models_243]\n",
    "best_idxs_384 = [np.argwhere(model_names==x)[0][0] for x in best_models_384]\n",
    "r2_best_384 = all_models_results_384[best_idxs_384]\n",
    "r2_best_243 = all_models_results_243[best_idxs_243]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "non_nan_384 = find_rows_without_nan(np.vstack((r2_best_384)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ordered_submodels = ['GPT2-XL', 'SP+SL', 'SP+SL+WORD', 'SP+SL+WORD+SENSE', 'SP+SL+WORD+SENSE+SYNT', 'SP+SL+WORD+SENSE+SYNT+GPT2-XL']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# store r2 valeus into a pd dataframe\n",
    "num_models = 6\n",
    "voxels_ids_243 = np.tile(np.arange(num_vox_dict['243']), num_models)\n",
    "br_labels_243 = np.tile(br_labels_dict['243'], num_models)\n",
    "model_order_243 = np.repeat(np.array(ordered_submodels), num_vox_dict['243'])\n",
    "\n",
    "r2_stacked_243 = np.hstack(r2_best_243)\n",
    "r2_stacked_pd_243 = pd.DataFrame({'Model': model_order_243, 'r2':r2_stacked_243, 'voxel_id': voxels_ids_243, \n",
    "                                  'Network': br_labels_243, 'subjects': np.tile(subjects_dict['243'], num_models)}).dropna()\n",
    "\n",
    "\n",
    "\n",
    "voxels_ids_384 = np.tile(np.arange(num_vox_dict['384']), num_models)\n",
    "br_labels_384 = np.tile(br_labels_dict['384'], num_models)\n",
    "model_order_384 = np.repeat(np.array(ordered_submodels), num_vox_dict['384'])\n",
    "\n",
    "r2_stacked_384 = np.hstack(r2_best_384)\n",
    "r2_stacked_pd_384 = pd.DataFrame({'Model': model_order_384, 'r2':r2_stacked_384, 'voxel_id': voxels_ids_384, \n",
    "                                   'Network': br_labels_384, \n",
    "                                   'subjects': np.tile(subjects_dict['384'], num_models)}).dropna()\n",
    "\n",
    "\n",
    "network_order=['language', 'DMN', 'MD', 'visual']\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subj_avg_pd_384, _, _ = plot_across_subjects(r2_stacked_pd_384.copy(), figurePath=figurePath, saveName='384_across_subjects_xl', \n",
    "                                             hue_order=ordered_submodels[1:], yticks=[0, 0.08], \n",
    "                               draw_lines=False, selected_networks=['language'], order=['language'], plot_legend=False, gpt2_perf= 0.031517)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(subj_avg_pd_384.groupby(['Model']).r2.mean())\n",
    "print(subj_avg_pd_384.groupby(['Model']).r2.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subj_avg_pd_243, _, _ = plot_across_subjects(r2_stacked_pd_243.copy(), figurePath=figurePath, saveName='243_across_subjects_xl', hue_order=ordered_submodels[1:], yticks=[0, 0.10], \n",
    "                               draw_lines=False, selected_networks=['language'], order=['language'], plot_legend=True, gpt2_perf=0.036352)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(subj_avg_pd_243.groupby(['Model']).r2.mean())\n",
    "print(subj_avg_pd_243.groupby(['Model']).r2.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subj_avg_pd_243_ri = subj_avg_pd_243.reset_index().drop(columns=['voxel_id']) \n",
    "SP_SL_store = []\n",
    "SP_SL_WORD_store = []\n",
    "SP_SL_WORD_SENSE_store = []\n",
    "SP_SL_WORD_SENSE_SYNT_store = []\n",
    "mode = 'GPT2'\n",
    "for s in np.unique(subj_avg_pd_243_ri.subjects):\n",
    "    subj_243_pd = subj_avg_pd_243_ri.loc[subj_avg_pd_243_ri.subjects==s]\n",
    "    mean_subj_243_pd = subj_243_pd.groupby(['Model']).r2.mean().reset_index()\n",
    "    SP_SL = mean_subj_243_pd.loc[mean_subj_243_pd.Model=='SP+SL'].reset_index().r2.values[0]\n",
    "    SP_SL_WORD = mean_subj_243_pd.loc[mean_subj_243_pd.Model=='SP+SL+WORD'].reset_index().r2.values[0]\n",
    "    SP_SL_WORD_SENSE = mean_subj_243_pd.loc[mean_subj_243_pd.Model=='SP+SL+WORD+SENSE'].reset_index().r2.values[0]\n",
    "    SP_SL_WORD_SENSE_SYNT = mean_subj_243_pd.loc[mean_subj_243_pd.Model=='SP+SL+WORD+SENSE+SYNT'].reset_index().r2.values[0]\n",
    "    all_models = mean_subj_243_pd.loc[mean_subj_243_pd.Model=='SP+SL+WORD+SENSE+SYNT+GPT2-XL'].reset_index().r2.values[0]\n",
    "    GPT2 = mean_subj_243_pd.loc[mean_subj_243_pd.Model=='GPT2-XL'].reset_index().r2.values[0]\n",
    "    \n",
    "    if mode == 'all':\n",
    "        denom = all_models\n",
    "    if mode == 'GPT2':\n",
    "        denom = GPT2\n",
    "    \n",
    "    SP_SL_store.append(SP_SL/denom)\n",
    "    SP_SL_WORD_store.append(SP_SL_WORD/denom)\n",
    "    SP_SL_WORD_SENSE_store.append(SP_SL_WORD_SENSE/denom)\n",
    "    SP_SL_WORD_SENSE_SYNT_store.append(SP_SL_WORD_SENSE_SYNT/denom)\n",
    "    \n",
    "print(np.mean(SP_SL_store), np.std(SP_SL_store))\n",
    "print(np.mean(SP_SL_WORD_store), np.std(SP_SL_WORD_store))\n",
    "print(np.mean(SP_SL_WORD_SENSE_store), np.std(SP_SL_WORD_SENSE_store))\n",
    "print(np.mean(SP_SL_WORD_SENSE_SYNT_store), np.std(SP_SL_WORD_SENSE_SYNT_store))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_DEM_model_384 = 'SL+WORD+SENSE+SYNT'\n",
    "best_DEM_model_243 = 'SL+WORD+SYNT'\n",
    "best_LLM_model_384 = 'SL+GPT2-XL'\n",
    "best_LLM_model_243 = 'SL+WORD+SYNT+GPT2-XL'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pass_info_plot_hist2d(max_val_dict={'language': 0.2, 'DMN': 0.2, 'MD':0.2, 'visual': 0.2}, min_val=-0.04, df=r2_stacked_pd_384, best_DEM_model='SP+SL+WORD', \n",
    "                      best_LLM_model='SP+SL+WORD+SENSE+SYNT+GPT2-XL', \n",
    "                      figurePath=figurePath, saveName='scatter_384_xl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pass_info_plot_hist2d(max_val_dict={'language': 0.25, 'DMN': 0.25, 'MD':0.25, 'visual': 0.25}, min_val=-0.05, df=r2_stacked_pd_243, best_DEM_model='SP+SL+WORD', best_LLM_model='SP+SL+WORD+SENSE+SYNT+GPT2-XL', \n",
    "                      figurePath=figurePath, saveName='scatter_243_xl')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Voxel-wise corrections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM_name = 'GPT2-XL'\n",
    "simple_model = 'SP+SL+WORD'\n",
    "full_model = 'SP+SL+WORD+SENSE+SYNTAX+GPT2-XL'\n",
    "remove_str_simple = \"GPT2-XL\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names_vc = ['SP+SL', 'WORD', 'SP+SL+WORD', 'SP+WORD', 'SL+WORD']\n",
    "gpt2_models = [x for x in model_names if 'GPT2-XL' in x and 'SENSE' not in x and 'SYNT' not in x]\n",
    "model_names_vc.extend(gpt2_models)\n",
    "sp_sl_word_gpt_models_idxs = [np.argwhere(model_names==x)[0][0] for x in model_names_vc]\n",
    "r2_vc_384 = all_models_results_384[sp_sl_word_gpt_models_idxs]\n",
    "r2_vc_243 = all_models_results_243[sp_sl_word_gpt_models_idxs]\n",
    "num_models_vc = len(model_names_vc)\n",
    "\n",
    "vc_PD_384 = pd.DataFrame({'voxel_id': np.tile(np.arange(num_vox_dict['384']), num_models_vc), \n",
    "                          'r2':np.hstack(r2_vc_384), 'Model': np.repeat(model_names_vc, num_vox_dict['384']), \n",
    "                          'Network': np.tile(br_labels_dict['384'], num_models_vc), \n",
    "                          'subjects': np.tile(subjects_dict['384'], num_models_vc)}).dropna()\n",
    "\n",
    "vc_PD_243 = pd.DataFrame({'voxel_id': np.tile(np.arange(num_vox_dict['243']), num_models_vc), \n",
    "                          'r2':np.hstack(r2_vc_243), 'Model': np.repeat(model_names_vc, num_vox_dict['243']), \n",
    "                          'Network': np.tile(br_labels_dict['243'], num_models_vc), \n",
    "                          'subjects': np.tile(subjects_dict['243'], num_models_vc)}).dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_model = 'SP+SL+WORD+GPT2-XL'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modified_r2_and_idxs(r2_stacked_pd, BL_str, remove_str_simple, nested_name, full_name):\n",
    "    \n",
    "    BL_model = remove_neg_r2(r2_stacked_pd.loc[r2_stacked_pd.Model==BL_str]).reset_index()\n",
    "    print(np.unique(r2_stacked_pd.loc[~r2_stacked_pd.Model.str.contains(remove_str_simple)].reset_index().Model))\n",
    "    nested_model, max_indices_nested = max_across_nested(r2_stacked_pd.loc[~r2_stacked_pd.Model.str.contains(remove_str_simple)].reset_index(), nested_name)\n",
    "    full_model, max_indices_full = max_across_nested(r2_stacked_pd.loc[r2_stacked_pd.Model.str.contains(BL_str)].reset_index(), full_name)\n",
    "    \n",
    "    return BL_model, nested_model, full_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BL_mod_384, non_BL_mod_384, full_mod_384 = \\\n",
    "            modified_r2_and_idxs(vc_PD_384, LLM_name, remove_str_simple, simple_model, full_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BL_mod_243, non_BL_mod_243, full_mod_243 = \\\n",
    "            modified_r2_and_idxs(vc_PD_243, LLM_name, remove_str_simple, simple_model, full_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_384 = pd.concat((non_BL_mod_384, full_mod_384))\n",
    "modified_243 = pd.concat((non_BL_mod_243, full_mod_243))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "default_palette = sns.color_palette(\"deep\")\n",
    "default_palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_across_subjects(dict_pd_merged, figurePath, selected_networks, yticks=None, saveName=None,\n",
    "                         color_palette=None, remove_auditory=True, hue_order=None, \n",
    "                         order=None, clip_zero=True, draw_lines=False, plot_legend=False, ms=10, gpt2_perf=None, plot_legend_under=False):\n",
    "    \n",
    "    if clip_zero:\n",
    "        dict_pd_merged['r2'] = np.where(dict_pd_merged['r2']<0, 0, dict_pd_merged['r2'])\n",
    " \n",
    "    dict_pd_with_all = dict_pd_merged.copy()\n",
    "    pattern = '|'.join(selected_networks)\n",
    "    dict_pd_merged = dict_pd_merged.loc[dict_pd_merged['Network'].str.contains(pattern)]\n",
    "        \n",
    "    subject_avg_pd = dict_pd_merged.groupby(['subjects', 'Network', 'Model']).mean()\n",
    "    \n",
    "    #plt.figure(figsize=(14,10))\n",
    "    sns.set_theme()\n",
    "    sns.set_style(\"white\")\n",
    "    sns.despine()\n",
    "    \n",
    "    fig, ax = plt.subplots(1,1, figsize=(6,6))\n",
    "    \n",
    "    sns.stripplot(data=subject_avg_pd, x='Network', y='r2', hue='Model', dodge=True, palette=color_palette, \n",
    "                   size=ms, hue_order=hue_order, order=order, ax=ax,  legend=plot_legend)\n",
    "    \n",
    "    if draw_lines:\n",
    "        for i in range(0, len(selected_networks)*2, 2):\n",
    "            locs1 = ax.get_children()[i].get_offsets()\n",
    "            locs2 = ax.get_children()[i+1].get_offsets()\n",
    "            for i in range(locs1.shape[0]):\n",
    "                x = [locs1[i, 0], locs2[i, 0]]\n",
    "                y = [locs1[i, 1], locs2[i, 1]]\n",
    "                ax.plot(x, y, color=\"black\", alpha=0.2)\n",
    "    \n",
    "    sns.barplot(data=subject_avg_pd, x='Network', y='r2', hue='Model', palette=color_palette, \n",
    "                alpha=0.5, errorbar=None, hue_order=hue_order, order=order, ax=ax, legend=False)\n",
    "    \n",
    "    if gpt2_perf is not None:\n",
    "        plt.axhline(gpt2_perf, linestyle='--', color='gray', linewidth=4)\n",
    "    \n",
    "    sns.despine()\n",
    "    \n",
    "    if plot_legend:\n",
    "        if plot_legend_under:\n",
    "            plt.legend(fontsize=25,frameon=False, bbox_to_anchor=(0.2, -0.10))\n",
    "        else:\n",
    "            plt.legend(fontsize=20,frameon=False, bbox_to_anchor=(1, 1), loc='upper left')\n",
    "\n",
    "    ax.set_ylabel('R2' + r\"$_{oos}$\", fontsize=40)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_xlabel('')\n",
    "    if yticks is not None:\n",
    "        ax.set_yticks(yticks)\n",
    "    plt.tick_params(axis='x', labelsize=30) \n",
    "    plt.tick_params(axis='y', labelsize=30) \n",
    "   \n",
    "    if saveName is not None:\n",
    "        plt.savefig(f'{figurePath}{saveName}.pdf', bbox_inches='tight')\n",
    "        plt.savefig(f'{figurePath}{saveName}.png', bbox_inches='tight', dpi=300)\n",
    "    plt.show()\n",
    "    \n",
    "    return subject_avg_pd, dict_pd_merged, dict_pd_with_all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, _ = plot_across_subjects(modified_384.copy(), figurePath=figurePath, selected_networks=['language'],\n",
    "                                             saveName='384_across_subjects_xl_mod', hue_order=[simple_model, full_model], \n",
    "                                             yticks=[0,0.08], \n",
    "                                                order=['language'], clip_zero=True, color_palette=[default_palette[1], default_palette[9]], draw_lines=True, ms=15, plot_legend=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, _ = plot_across_subjects(modified_243.copy(), figurePath=figurePath, selected_networks=['language'],\n",
    "                                             saveName='243_across_subjects_xl_mod', hue_order=[simple_model, full_model], \n",
    "                                             yticks=[0,0.10], \n",
    "                                                order=['language'], clip_zero=True, color_palette=[default_palette[1], default_palette[9]], draw_lines=True, ms=15, plot_legend=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SPW_corrected_r2_243 = modified_243.loc[modified_243.Model=='SP+SL+WORD'].r2.values\n",
    "BLSPW_corrected_r2_243 = modified_243.loc[modified_243.Model==full_model].r2.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BLSPW_corrected_r2_243.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting_folder = \"/data/LLMs/Pereira/plotting_data/\"\n",
    "subjects, stored_data_exp_WN_pos = save_fMRI_simple(SPW_corrected_r2_243, exp='243', \n",
    "                        subjects_to_plot=np.unique(subjects_dict['243']), \n",
    "                        subjects_all=subjects_dict['243'], save_name='SP+SL+WORD_243') \n",
    "\n",
    "subjects, stored_data_exp_WN_pos = save_fMRI_simple(BLSPW_corrected_r2_243, exp='243', \n",
    "                        subjects_to_plot=np.unique(subjects_dict['243']), \n",
    "                        subjects_all=subjects_dict['243'], save_name='FULL_243')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting.plot_glass_brain(f'{plotting_folder}SP+SL+WORD_243_426.nii', \n",
    "                            colorbar=False, display_mode='xz', vmax=0.4, output_file=f'{figurePath}glass_brain_SP+SL+WOWRD_243_426.pdf')\n",
    "plotting.plot_glass_brain(f'{plotting_folder}FULL_243_426.nii', \n",
    "                            colorbar=True, display_mode='xz', vmax=0.4, output_file=f'{figurePath}glass_brain_FULL_243_426.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SPW_corrected_r2_384 = modified_384.loc[modified_384.Model=='SP+SL+WORD'].r2.values\n",
    "BLSPW_corrected_r2_384 = modified_384.loc[modified_384.Model==full_model].r2.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BLSPW_corrected_r2_384.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting_folder = \"/data/LLMs/Pereira/plotting_data/\"\n",
    "subjects, stored_data_exp_WN_pos = save_fMRI_simple(SPW_corrected_r2_384, exp='384', \n",
    "                        subjects_to_plot=np.unique(subjects_dict['384']), \n",
    "                        subjects_all=subjects_dict['384'][non_nan_384].squeeze(), save_name='SP+SL+WORD_384') \n",
    "\n",
    "subjects, stored_data_exp_WN_pos = save_fMRI_simple(BLSPW_corrected_r2_384, exp='384', \n",
    "                        subjects_to_plot=np.unique(subjects_dict['384']), \n",
    "                        subjects_all=subjects_dict['384'][non_nan_384].squeeze(), save_name='FULL_384')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting.plot_glass_brain(f'{plotting_folder}SP+SL+WORD_384_426.nii', \n",
    "                            colorbar=False, display_mode='xz', vmax=0.35, output_file=f'{figurePath}glass_brain_SP+SL+WOWRD_384_426.pdf')\n",
    "plotting.plot_glass_brain(f'{plotting_folder}FULL_384_426.nii', \n",
    "                            colorbar=True, display_mode='xz', vmax=0.35, output_file=f'{figurePath}glass_brain_FULL_384_426.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
