{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import brainio\n",
    "import numpy as np\n",
    "base = '/home3/name/what-is-brainscore/'\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import sys\n",
    "sys.path.append('/home3/name/what-is-brainscore/')\n",
    "from helper_funcs import *\n",
    "from scipy.stats import pearsonr\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stack_features(model_dict, filename, saveFolder='/home3/name/what-is-brainscore/data_processed/', dataset='pereira'):\n",
    "\n",
    "    saveFolder = f'{saveFolder}{dataset}/'\n",
    "    stacked_name = ''\n",
    "    stored_acts = []\n",
    "    for key, values in model_dict.items():\n",
    "        for value in values:\n",
    "            val = np.load(f'{saveFolder}X_{key}.npz')[value]\n",
    "            if len(val.shape) == 1:\n",
    "                val = np.expand_dims(val, axis=-1)\n",
    "            stored_acts.append(val)\n",
    "            print(val.shape)\n",
    "            \n",
    "    X_stacked = {'layer1': np.hstack(stored_acts)}\n",
    "    np.savez(f'{saveFolder}X_{filename}', **X_stacked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl': ['layer_21'], 'synt-kauf-xl-21': ['layer1'], 'EXP': ['layer1'], \n",
    "                'word-num': ['layer1']}, filename='gpt2-xl-bil_static_word_EWN')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl': ['layer_21']}, filename='gpt2-xl-bil')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'positional_simple': ['layer1'], 'word-num': ['layer1']}, filename='positional_WN')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'gpt2-xl': ['layer_21'],\n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-static_bil_DEM_u')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'x2static': ['layer1'], 'syntax-kauf': ['layer1'], \n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-static_X2SYO_PWE')\n",
    "\n",
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'gpt2-xl': ['layer_21'], 'x2static': ['layer1'], 'syntax-kauf': ['layer1'],\n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-static_bil_X2SYO_PWE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'x2static': ['layer1'], 'synt-kauf-xl-l21': ['layer1'], \n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-static_X2SY_PWE')\n",
    "\n",
    "\n",
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'gpt2-xl': ['layer_21'], 'x2static': ['layer1'], 'synt-kauf-xl-l21': ['layer1'],\n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-static_bil_X2SY_PWE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'WORD': ['layer1'], 'synt-kauf-xl-l21': ['layer1'], \n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-static_WSY_PWE')\n",
    "\n",
    "\n",
    "stack_features({'gpt2-xl': ['layer_21'], 'gpt2-xl-sp-static': ['layer1'], 'WORD': ['layer1'], 'synt-kauf-xl-l21': ['layer1'],\n",
    "                'positional_simple': ['layer1'], 'word-num': ['layer1'], 'EXP':['layer1']}, filename='gpt2-xl-bil_static_WSY_PWE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-sp-static': ['layer1'], 'gpt2-xl': ['layer_21']}, filename='gpt2-xl_with_static')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'gpt2-xl'\n",
    "save_model_name = 'gpt2xl'\n",
    "N = 18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'word-num': ['layer1'], 'positional_simple': ['layer1']}, filename='word-num_pos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'phrase': ['layer1'], 'token-num': ['layer1'], 'positional_simple': ['layer1']}, filename='phrase_token-num_pos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(N):\n",
    "    stack_features({f'{model_name}-untrained-sp_m{i}': ['static']}, filename=f'{save_model_name}-ut_static_m{i}')\n",
    "    stack_features({f'{model_name}-untrained-sp_m{i}': ['lua']}, filename=f'{save_model_name}-ut_lua_m{i}')\n",
    "    stack_features({f'{model_name}-untrained-sp_m{i}': ['lua', 'static']}, filename=f'{save_model_name}-ut_lua_static_m{i}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 11\n",
    "best_layer = []\n",
    "best_layer_lang = []\n",
    "model_name = 'gpt2-xl'\n",
    "exp = '384'\n",
    "for i in range(N):\n",
    "    # Only the array names are needed here, so read them from the archive index\n",
    "    # (NpzFile.files) instead of loading every array, and close the file afterwards.\n",
    "    with np.load(f'/home3/name/what-is-brainscore/data_processed/pereira/X_{model_name}-untrained-sp-{exp}_m{i}.npz') as archive:\n",
    "        keys = list(archive.files)\n",
    "    # First array name containing 'all' / 'lang'; IndexError if the archive has none.\n",
    "    bil_all = [k for k in keys if 'all' in k][0]\n",
    "    bil_lang = [k for k in keys if 'lang' in k][0]\n",
    "    best_layer.append(bil_all)\n",
    "    best_layer_lang.append(bil_lang)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(N):\n",
    "    bil = best_layer[i]\n",
    "    #stack_features({f'{model_name}-untrained-sp_m{i}': [bil],  'EXP': ['layer1'], 'word-num': ['layer1'], 'positional_simple': ['layer1']}, \n",
    "    #               filename=f'{save_model_name}-ut_bil_EXPWNPOS_m{i}')\n",
    "    bil_lang = best_layer_lang[i]\n",
    "    stack_features({f'{model_name}-untrained-sp-{exp}_m{i}': [bil_lang], 'positional_simple': ['layer1'],  'word-num': ['layer1']}, \n",
    "                   filename=f'{model_name}-ut_bil-lang_POSWN_{exp}_m{i}')\n",
    "    stack_features({f'{model_name}-untrained-sp-{exp}_m{i}': [bil_lang], 'positional_simple': ['layer1']}, \n",
    "                   filename=f'{model_name}-ut_bil-lang_POS_{exp}_m{i}')\n",
    "    stack_features({f'{model_name}-untrained-sp-{exp}_m{i}': [bil_lang], 'word-num': ['layer1']}, \n",
    "                   filename=f'{model_name}-ut_bil-lang_WN_{exp}_m{i}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'EXP': ['layer1'], 'word-num': ['layer1'], 'positional_simple': ['layer1']}, \n",
    "                   filename=f'EXPWNPOS')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(N):\n",
    "    bil = best_layer[i]\n",
    "    print(bil)\n",
    "    stack_features({f'gpt2-large-untrained-sp_m{i}': ['lua', 'static'], 'token-num': ['layer1']}, filename=f'gpt2l-ut_lua_static_token-num_m{i}')\n",
    "    stack_features({f'gpt2-large-untrained-sp_m{i}': [bil, 'lua', 'static'], 'token-num': ['layer1']}, filename=f'gpt2l-ut_bil_lua_static_token-num_m{i}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'word-num': ['layer1'], 'gpt2-large-sp-hfgpt': ['embedding+pos'], 'sense': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "find_best_layer('gpt2-large-untrained-sp', resultsFolder='/home3/name/what-is-brainscore/results_all/results_pereira/', exclude_str=['384', '243'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'sense-deref': ['layer1'],\n",
    "                'gpt2-large-static-2': ['layer1'], \n",
    "                'syntax': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'sense-deref': ['layer1'],\n",
    "                'gpt2-large-static-2': ['layer1'], \n",
    "                'EXP': ['layer1'],\n",
    "                'word-num': ['layer1'], \n",
    "                'first-sent-384': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'word-num': ['layer1'], 'EXP': ['layer1'], 'sense': ['layer1'], \n",
    "                'gpt2-large-sp-hfgpt': ['embedding'], 'gpt2-large-hfgpt' : ['encoder.h.18']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interp = dict(np.load(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_positional_simple_word-num_EXP.npz\"))['layer1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_num = np.arange(30)\n",
    "model_dict = {}\n",
    "model_dict_all = {}\n",
    "for mn in model_num:\n",
    "\n",
    "    # Copy the arrays out inside a with-block so each .npz file handle is closed\n",
    "    # before the next archive is opened.\n",
    "    with np.load(f'/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_gpt2-large-untrained-sp-hfgpt_{mn}.npz') as archive:\n",
    "        mn_activity = dict(archive)\n",
    "\n",
    "    # Array names that contain 'encoder' but not 'no-nonlin'; the first one is used.\n",
    "    bil_key = [key for key in mn_activity.keys() if 'encoder' in key and 'no-nonlin' not in key]\n",
    "    print(bil_key)\n",
    "    # wr + lua + static embed + bil\n",
    "    interp_and_bil = np.hstack((mn_activity[bil_key[0]], interp))\n",
    "\n",
    "    model_dict_all[str(mn)] = interp_and_bil\n",
    "\n",
    "print(interp_and_bil.shape)\n",
    "\n",
    "# One stacked array per model number, keyed by the number as a string.\n",
    "np.savez(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_BIL+interp\", **model_dict_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interp2 = dict(np.load(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_positional_simple_word-num.npz\"))['layer1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_num = np.arange(30)\n",
    "model_dict = {}\n",
    "model_dict_all_384 = {}\n",
    "model_dict_all_243 = {}\n",
    "interp_keys = ['attn_scaled_l1', 'embedding+pos']\n",
    "for mn in model_num:\n",
    "\n",
    "    # Copy the arrays out inside with-blocks so both .npz file handles are closed\n",
    "    # before the next pair of archives is opened.\n",
    "    with np.load(f'/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_gpt2-large-untrained-sp-hfgpt_{mn}_384.npz') as archive_384:\n",
    "        mn_activity_384 = dict(archive_384)\n",
    "    with np.load(f'/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_gpt2-large-untrained-sp-hfgpt_{mn}_243.npz') as archive_243:\n",
    "        mn_activity_243 = dict(archive_243)\n",
    "\n",
    "    # Array names that contain 'encoder' but not 'no-nonlin'; the first one is used.\n",
    "    bil_key_384 = [key for key in mn_activity_384.keys() if 'encoder' in key and 'no-nonlin' not in key]\n",
    "    bil_key_243 = [key for key in mn_activity_243.keys() if 'encoder' in key and 'no-nonlin' not in key]\n",
    "\n",
    "    print(bil_key_384, bil_key_243)\n",
    "\n",
    "    interp_and_bil_384 = np.hstack((mn_activity_384[bil_key_384[0]], interp2))\n",
    "    interp_and_bil_243 = np.hstack((mn_activity_243[bil_key_243[0]], interp2))\n",
    "\n",
    "    model_dict_all_384[str(mn)] = interp_and_bil_384\n",
    "    model_dict_all_243[str(mn)] = interp_and_bil_243\n",
    "\n",
    "print(interp_and_bil_384.shape, interp_and_bil_243.shape)\n",
    "\n",
    "# One stacked array per model number, keyed by the number as a string.\n",
    "np.savez(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_BIL+interp_384\", **model_dict_all_384)\n",
    "np.savez(\"/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_BIL+interp_243\", **model_dict_all_243)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "wr_square_root = np.sqrt(lwr[:, 1])\n",
    "np.savez('/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_wr-sq', **{'layer1': np.vstack((wr_square_root, lwr[:, 1])).T})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_num = np.arange(20)\n",
    "# range() cannot take an ndarray (TypeError), so iterate over the model numbers directly.\n",
    "# NOTE(review): mn is not used by the call below, so every pass stacks the same\n",
    "# inputs - confirm whether a per-model file key was intended.\n",
    "for mn in model_num:\n",
    "    stack_features({'gpt2-large-untrained-sp-hfgpt': ['layer1'], 'roberta-large-sp-hfgpt': ['encoder_15']}, dataset='pereira')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'word-rate': ['layer1'], 'gpt2-large-untrained-sp-hfgpt_3': ['embedding+pos', 'encoder.h.3', 'attn_scaled_l1']}, dataset='pereira')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-sp-hfgpt': ['embedding'], \n",
    "                'positional_simple': ['layer1'], }, dataset='pereira')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt': ['encoder.h.17'], 'letter-word-rate': ['layer1'], 'positional': ['layer1']}, dataset='pereira')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-untrained-hfgpt': ['embedding+pos_m', 'embedding+pos_m']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({ 'gpt2-large-untrained-hfgpt': ['encoder.h.34', 'embedding']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt' : ['layer1'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'letter-word-rate' : ['layer1'], 'gpt2-large-hfgpt-surprisal': ['layer1'], 'glove': ['projection'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'letter-word-rate' : ['layer1'], 'glove': ['projection'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt' : ['embedding'], 'letter-word-rate': ['layer1'], \n",
    "                'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt-sentencewise': ['encoder.h.17'], 'letter-word-rate' : ['layer1'], 'glove': ['projection'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'skip-thoughts': ['encoder'], 'letter-word-rate' : ['layer1'], 'glove': ['projection'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt': ['layer1'], 'letter-word-rate' : ['layer1'], 'glove': ['projection'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt': ['encoder.h.7'], 'letter-word-rate' : ['layer1'], 'positional': ['layer1']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_layer, _, _, _ = find_best_layer('gpt2-xl-untrained')\n",
    "best_layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-untrained': 'encoder.h.8', 'glove-untrained': 'projection', 'positional_simple': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'glove-untrained': 'projection', 'positional_simple': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-untrained': 'encoder.h.8', 'glove-untrained': 'projection', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_gpt2_xl = np.load(f\"{base}temp_data_all/temp_data/X_gpt2-xl-hfgpt.npz\")\n",
    "X_gpt2_large_20 = X_gpt2_xl['encoder.h.20']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-xl-hfgpt': 'encoder.h.20', 'glove' : 'projection',  \n",
    "                'gpt2-xl-hfgpt-sentencewise-surprisal' : 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'glove': 'projection', 'gpt2-large-hfgpt-sentencewise-surprisal': 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'glove': 'projection', 'gpt2-large-hfgpt-sentencewise-surprisal': 'layer1', 'dlt_max_small': 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'dlt_max_small': 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'dlt_max_small': 'layer1', 'gpt2-large-hfgpt-sentencewise-surprisal': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'dlt_max_small': 'layer1', 'glove': 'projection'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'glove': 'projection', 'gpt2-large-hfgpt-sentencewise-surprisal': 'layer1', 'dlt_max': 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt': 'encoder.h.18', 'glove': 'projection', 'gpt2-large-hfgpt-sentencewise-surprisal': 'layer1', 'dlt_max': 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'gpt2-large-hfgpt': 'encoder.h.18', 'dlt_max_small': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack_features({'glove': 'projection', 'gpt2-large-hfgpt-sentencewise-surprisal': 'layer1', 'dlt_max_small': 'layer1', 'positional': 'layer1'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_surprisal_sentence_wise = np.load('/home3/name/what-is-brainscore/temp_data_all/temp_data/X_gpt2-large-hfgpt-sentencewise-surprisal.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_surprisal = np.load('/home3/name/what-is-brainscore/temp_data_all/temp_data/X_gpt2-large-hfgpt-surprisal.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_both = np.hstack((X_surprisal_sentence_wise['layer1'], X_surprisal['layer1']))\n",
    "np.savez('/home3/name/what-is-brainscore/temp_data_all/temp_data/X_gpt2-large-hfgpt-surprisal-both', **{'layer1':X_both})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_surprisal_sentence_wise_np = X_surprisal_sentence_wise['layer1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_glove = np.load(f'{base}temp_data_all/temp_data/X_glove.npz')\n",
    "X_glove_np = X_glove['projection']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_position_np = create_positional_signal()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_surprisal_sentence_wise_np.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_dlt = np.load('/home3/name/what-is-brainscore/temp_data_all/temp_data/X_dlt.npz')['layer1']\n",
    "X_dlt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_surprisal_position = np.hstack((X_glove_np, X_surprisal_sentence_wise_np, X_position_np))\n",
    "X_glove_surprisal_positon = {'layer1': glove_surprisal_position}  \n",
    "np.savez(f'{base}temp_data_all/temp_data/X_glove_surprisal_positon', **X_glove_surprisal_positon)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_surprisal_dlt_position = np.hstack((X_glove_np, X_surprisal_sentence_wise_np, X_dlt, X_position_np))\n",
    "print(glove_surprisal_dlt_position.shape)\n",
    "X_glove_surprisal_dlt_positon = {'layer1': glove_surprisal_dlt_position}  \n",
    "np.savez(f'{base}temp_data_all/temp_data/X_glove_surprisal_dlt_positon', **X_glove_surprisal_dlt_positon)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_file, best_results, val_perf, best_test_results = find_best_layer(model='gpt2-large-hfgpt', \n",
    "                resultsFolder='/home3/name/what-is-brainscore/results_all/results-himalayas')\n",
    "X_gpt2_large = np.load(f\"{base}temp_data_all/temp_data/X_gpt2-large-hfgpt.npz\")\n",
    "X_gpt2_large_18 = X_gpt2_large['encoder.h.18']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt_glove_surprisal_position = np.hstack((X_gpt2_large_18, X_glove_np, X_surprisal_sentence_wise_np, X_position_np))\n",
    "X_gpt_glove_surprisal_positon = {'layer1': gpt_glove_surprisal_position}  \n",
    "np.savez(f'{base}temp_data_all/temp_data/X_gpt_glove_surprisal_position', **X_gpt_glove_surprisal_positon)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt_glove_surprisal_dlt_position = np.hstack((X_gpt2_large_18, X_glove_np, X_surprisal_sentence_wise_np, X_dlt, X_position_np))\n",
    "X_gpt_glove_surprisal_dlt_positon = {'layer1': gpt_glove_surprisal_dlt_position}  \n",
    "print(gpt_glove_surprisal_dlt_position.shape)\n",
    "np.savez(f'{base}temp_data_all/temp_data/X_gpt_glove_surprisal_dlt_position', **X_gpt_glove_surprisal_dlt_positon)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the DLT features with the positional signal and save them on their own.\n",
    "dlt_position = np.hstack((X_dlt, X_position_np))\n",
    "X_dlt_positon = {'layer1': dlt_position}\n",
    "# Report and save the array built in this cell; the cell used to print and re-save\n",
    "# gpt_glove_surprisal_dlt_position, overwriting the file written by the previous cell.\n",
    "# NOTE(review): the output name X_dlt_position is chosen here - confirm it matches\n",
    "# what downstream code expects.\n",
    "print(dlt_position.shape)\n",
    "np.savez(f'{base}temp_data_all/temp_data/X_dlt_position', **X_dlt_positon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_file, best_results, val_perf, best_test_results = find_best_layer(model='gpt2-large-hfgpt-sentencewise', \n",
    "                resultsFolder='/home3/name/what-is-brainscore/results_all/results-himalayas')\n",
    "print(best_file)\n",
    "X_gpt2_large_sw = np.load(f\"{base}temp_data_all/temp_data/X_gpt2-large-hfgpt-sentencewise.npz\")\n",
    "# Read layer 17 from the sentencewise archive loaded above; indexing X_gpt2_large\n",
    "# here would silently return the non-sentencewise activations instead.\n",
    "X_gpt2_large_sw_17 = X_gpt2_large_sw['encoder.h.17']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gptsw_position = np.hstack((X_gpt2_large_sw_17 , X_position_np))\n",
    "X_gptsw_positon = {'layer1': gptsw_position}  \n",
    "np.savez(f'{base}temp_data_all/temp_data/X_gptsw_position', **X_gptsw_positon)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_glove_content = np.load(f'{base}temp_data_all/temp_data/X_glove_content.npz')\n",
    "X_glove_content_np = X_glove_content['layer1']\n",
    "glove_content_position = np.hstack((X_glove_content_np, X_position_np))\n",
    "X_glove_content_positon = {'layer1': glove_content_position}  \n",
    "np.savez(f'{base}temp_data_all/temp_data/X_glove_content_positon', **X_glove_content_positon)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_glove_prev_sent= np.load(f'{base}temp_data_all/temp_data/X_glove_prev_sent.npz')\n",
    "X_glove_prev_sent_np =X_glove_prev_sent['layer1']\n",
    "glove_prev_sent_position = np.hstack((X_glove_prev_sent_np, X_position_np))\n",
    "X_glove_prev_sent_positon = {'layer1': glove_prev_sent_position}  \n",
    "np.savez(f'{base}temp_data_all/temp_data/X_glove_prev_sent_positon', **X_glove_prev_sent_positon)  "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
