{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import brainio\n",
    "import sys\n",
    "sys.path.append('/home3/name/what-is-brainscore/')\n",
    "from helper_funcs import combine_MSE_across_folds\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from matplotlib import pyplot as plt\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base = '/home3/name/what-is-brainscore/results_all/'\n",
    "resultsFolder = f'{base}results_pereira/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nc_file_pereira = '/home3/name/what-is-brainscore/pereira_data/Pereira_data.nc'\n",
    "pereira_data = brainio.assemblies.DataAssembly.from_files(nc_file_pereira)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "language_idxs = np.asarray(pereira_data.atlas=='language')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'pereira'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key_intercept = 'intercept_only'\n",
    "model_res_intercept = dict(np.load(f\"{resultsFolder}{dataset}_{key_intercept}.npz\"))\n",
    "y_test = model_res_intercept['y_test_folds']\n",
    "mse_intercept_only = combine_MSE_across_folds(model_res_intercept['mse_stored'], dataset='pereira')\n",
    "mse_intercept_only.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = 'gpt_glove_surprisal_position_kern'\n",
    "val = 'layer1_1000'\n",
    "model_res = dict(np.load(f\"{resultsFolder}{dataset}_{key}_{val}.npz\"))\n",
    "key = \"glove_surprisal_position\"\n",
    "val = 'layer1_1000'\n",
    "model_res_glove = dict(np.load(f\"{resultsFolder}{dataset}_{key}_{val}.npz\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hat_LI = model_res_glove['y_hat_folds']\n",
    "y_hat_LI_gpt = model_res['y_hat_folds']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# real R2 for each model\n",
    "mse_LI_gpt = mean_squared_error(y_hat_LI_gpt, y_test, multioutput='raw_values')\n",
    "mse_LI = mean_squared_error(y_hat_LI, y_test, multioutput='raw_values')\n",
    "\n",
    "r2_LI_gpt = 1 - mse_LI_gpt/mse_intercept_only\n",
    "r2_LI = 1 - mse_LI/mse_intercept_only\n",
    "\n",
    "print(r2_LI_gpt.mean())\n",
    "print(r2_LI.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = y_test.shape[0]\n",
    "num_voxels = y_test.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now try shuffling and recomputing the R2\n",
    "swap_number= int(0.5*num_samples)\n",
    "print(swap_number)\n",
    "num_permutations = 1000\n",
    "swapped_r2_LI_gpt = np.zeros((num_permutations, num_voxels))\n",
    "for i in range(num_permutations):\n",
    "    if i % 100 == 0:\n",
    "        print(i)\n",
    "    y_hat_LI_gpt_swapped = copy.deepcopy(y_hat_LI_gpt)\n",
    "    swap_indices = np.random.choice(num_samples, swap_number, replace=False)\n",
    "    y_hat_LI_gpt_swapped[swap_indices] = y_hat_LI[swap_indices]\n",
    "    mse_LI_gpt_swapped = mean_squared_error(y_hat_LI_gpt_swapped, y_test, multioutput='raw_values')\n",
    "    r2_LI_gpt_swapped = 1 - mse_LI_gpt_swapped/mse_intercept_only\n",
    "    swapped_r2_LI_gpt[i] = r2_LI_gpt_swapped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r2_LI_gpt_swapped.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r2_diff = r2_LI_gpt - r2_LI\n",
    "r2_diff_sorted = np.sort(r2_diff)\n",
    "idx = 40000\n",
    "print(r2_diff_sorted[idx])\n",
    "r2_diff_small_diff = np.argwhere(r2_diff==r2_diff_sorted[idx])\n",
    "\n",
    "# Let's see what the p-value for this small difference in r2 is \n",
    "swapped_scores_small_diff = swapped_r2_LI_gpt[:, r2_diff_small_diff].squeeze()\n",
    "plt.hist(swapped_scores_small_diff)\n",
    "plt.axvline(r2_diff_sorted[idx], color='r')\n",
    "plt.xlabel(\"R2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def r2_difference(model1_preds, model2_preds, ytest, num_permutations=1000):\n",
    "    \n",
    "    '''\n",
    "    ndarray model1_preds: predictions from first model\n",
    "    ndarray model2_preds: predictions from second, nested model\n",
    "    ndarray ytest: groundtruth data \n",
    "    int num_permutations: number of swaps to do \n",
    "    \n",
    "    Swaps 50% of predictions between model1 and model2, and then recomputes out of sample r2.\n",
    "    This is done num_permutation times to create a distribution.\n",
    "    '''\n",
    "\n",
    "    num_samples = y_test.shape[0]\n",
    "    num_voxels = y_test.shape[1]\n",
    "    \n",
    "    # now try shuffling and recomputing the R2\n",
    "    swap_number= int(0.5*num_samples)\n",
    "\n",
    "    swapped_r2 = np.zeros((num_permutations, num_voxels))\n",
    "    \n",
    "    for i in range(num_permutations):\n",
    "        \n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "            \n",
    "        model1_preds_swapped = copy.deepcopy(model1_preds)\n",
    "        swap_indices = np.random.choice(num_samples, swap_number, replace=False)\n",
    "        model1_preds_swapped[swap_indices] = model2_preds[swap_indices]\n",
    "        mse_LI_gpt_swapped = mean_squared_error(model1_preds_swapped, y_test, multioutput='raw_values')\n",
    "        r2_swapped = 1 - mse_LI_gpt_swapped/mse_intercept_only\n",
    "        swapped_r2[i] = r2_swapped\n",
    "        \n",
    "    \n",
    "        \n",
    "    \n",
    "\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.argwher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.argwhere(r2_LI_gpt>swapped_r2_LI_gpt.mean(axis=0)).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "swapped_r2_LI_gpt.mean(axis=0).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.argwhere(r2_LI_gpt > r2_LI).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(np.mean(swapped_r2_LI_gpt,axis=1))\n",
    "plt.axvline(r2_LI_gpt.mean(), linestyle='--', color='r')\n",
    "plt.ylabel(\"Count\")\n",
    "plt.xlabel(\"Mean of swapped r2 across voxels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def permute_matrix_in_blocks(y_pred, mse_intercept_only, block_size, num_permutations=1000):\n",
    "    \n",
    "    permuted_r2 = np.zeros((num_permutations, y_pred.shape[1]))\n",
    "    \n",
    "    # Get the number of rows in the y_pred\n",
    "    num_rows = y_pred.shape[0]\n",
    "    # Define the block size\n",
    "    \n",
    "    # Calculate the number of blocks\n",
    "    num_blocks = num_rows // block_size\n",
    "    \n",
    "    # Create a list of block indices and shuffle them\n",
    "    block_indices = list(range(num_blocks))\n",
    "    \n",
    "    real_r2 = 1 - mean_squared_error(y_pred, y_test, multioutput='raw_values')/mse_intercept_only\n",
    "    \n",
    "    print(\"R2 mean: \", real_r2.mean())\n",
    "    \n",
    "    for n in range(num_permutations):\n",
    "        \n",
    "        if n % 100 == 0:\n",
    "            print(n)\n",
    "        \n",
    "        np.random.shuffle(block_indices)\n",
    "        \n",
    "        # Create an empty y_pred to store the shuffled rows\n",
    "        shuffled_y_pred = np.empty_like(y_pred)\n",
    "        \n",
    "        # Iterate through the shuffled block indices\n",
    "        for i, block_index in enumerate(block_indices):\n",
    "            # Calculate the start and end row indices for the current block\n",
    "            start_row = block_index * block_size\n",
    "            end_row = start_row + block_size\n",
    "            \n",
    "            # Copy the rows from the original y_pred to the shuffled y_pred\n",
    "            shuffled_y_pred[i * block_size : (i + 1) * block_size, :] = y_pred[start_row:end_row, :]\n",
    "            \n",
    "        mse_model = mean_squared_error(shuffled_y_pred, y_test, multioutput='raw_values')\n",
    "        \n",
    "        permuted_r2[n] = 1 - mse_model/mse_intercept_only\n",
    "                    \n",
    "    return permuted_r2, real_r2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permuted_r2, real_r2  = permute_matrix_in_blocks(y_hat_folds, mse_intercept_only, \n",
    "                                          block_size=3, num_permutations=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_r2_z_scored = (real_r2 - np.mean(permuted_r2, axis=0))/np.std(permuted_r2,axis=0)\n",
    "plt.hist(real_r2_z_scored)\n",
    "plt.show()\n",
    "plt.hist(permuted_r2[:, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# find fraction of permuted shuffled r2 that are greater than the real r2\n",
    "# to compute p-value\n",
    "num_permutations = 1000\n",
    "real_less_than_shuffled = real_r2 - permuted_r2\n",
    "real_less_than_shuffled[real_less_than_shuffled>=0] = 0\n",
    "real_less_than_shuffled[real_less_than_shuffled<0] = 1\n",
    "p_vals = np.sum(real_less_than_shuffled,axis=0)/num_permutations\n",
    "print(p_vals.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(p_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subjects = np.unique(pereira_data.subject)\n",
    "networks = np.unique(pereira_data.atlas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "col_to_coord_1 = pereira_data.col_to_coord_1\n",
    "col_to_coord_2 = pereira_data.col_to_coord_2\n",
    "col_to_coord_3 = pereira_data.col_to_coord_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p_vals.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p_vals_stored = {}\n",
    "r2_vals_stored = {}\n",
    "SPM_dim = (79,95,69)\n",
    "\n",
    "store_scores_by_network = {}\n",
    "p_vals_426 = {}\n",
    "for n in networks:\n",
    "    store_scores_by_network[n] = []\n",
    "    \n",
    "    p_vals_426[n] = []\n",
    "    for s in subjects:\n",
    "        \n",
    "    \n",
    "        print(s, n)\n",
    "        \n",
    "        p_vals_stored[f'{s}_{n}'] = np.full(SPM_dim, np.nan)\n",
    "        r2_vals_stored[f'{s}_{n}'] = np.full(SPM_dim, np.nan)\n",
    "        \n",
    "        \n",
    "        subj_idxs = np.asarray((pereira_data.subject==s))\n",
    "        network_idxs = np.asarray((pereira_data.atlas==n))\n",
    "        subj_network_idxs = np.logical_and(subj_idxs, network_idxs)\n",
    "        \n",
    "        p_vals_sn = p_vals[subj_network_idxs]\n",
    "        r2_vals_sn = real_r2[subj_network_idxs]\n",
    "        \n",
    "        store_scores_by_network[n].append(r2_vals_sn.mean())\n",
    "        \n",
    "        col_to_coord_1_sn = np.array(col_to_coord_1[subj_network_idxs])\n",
    "        col_to_coord_2_sn = np.array(col_to_coord_2[subj_network_idxs])\n",
    "        col_to_coord_3_sn = np.array(col_to_coord_3[subj_network_idxs])\n",
    "        \n",
    "        p_vals_426[n].extend(p_vals_sn)\n",
    "        \n",
    "        for i, (x,y,z) in enumerate(zip(col_to_coord_1_sn, col_to_coord_2_sn, col_to_coord_3_sn)):\n",
    "            p_vals_stored[f'{s}_{n}'][x,y,z] = p_vals_sn[i]\n",
    "            r2_vals_stored[f'{s}_{n}'][x,y,z] = r2_vals_sn[i]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savez(f'/data/LLMs/Pereira/stats_results/{key}_r2', **r2_vals_stored)\n",
    "np.savez(f'/data/LLMs/Pereira/stats_results/{key}_pval', **p_vals_stored)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "store_scores_by_network['language']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(p_vals_426['language'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(p_vals_426['auditory'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
