{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_splits_pereira(X, y, data_labels, alphas, feature_grouper, \n",
    "                             n_iter, use_kernelized, dataset, features_list, \n",
    "                             test_indices_arr, experiments, n_passages_384, \n",
    "                             n_passages_243, exp):\n",
    "    \"\"\"Evaluate every leave-one-passage-out fold of the Pereira experiments.\n",
    "\n",
    "    Each passage of each experiment in `experiments` is held out once as the\n",
    "    test fold. All remaining samples are handed to `run_himalayas` as training\n",
    "    data, together with the passages (and their experiment names) that are\n",
    "    available for validation.\n",
    "\n",
    "    NOTE: the accumulators `val_stored`, `mse_stored_intercept_only`,\n",
    "    `mse_stored`, `y_hat_folds`, `y_test_folds` and `test_fold_size` are not\n",
    "    created here. They are looked up in the surrounding notebook namespace and\n",
    "    must already exist as lists before this function is called. `val_stored`\n",
    "    is appended to but is not part of the return value.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X, y : arrays indexed by sample along the first axis.\n",
    "    data_labels : np.ndarray\n",
    "        Per-sample labels; passed to `split_by_exp_passage_num` and indexed to\n",
    "        obtain the training labels.\n",
    "    alphas, feature_grouper, n_iter, use_kernelized, dataset, features_list :\n",
    "        Forwarded unchanged to `run_himalayas`.\n",
    "    test_indices_arr : list\n",
    "        Extended in place with the test indices of every fold.\n",
    "    experiments : iterable of str\n",
    "        Experiments to hold passages out from. '384' selects `n_passages_384`;\n",
    "        any other value selects `n_passages_243`.\n",
    "    n_passages_384, n_passages_243 : np.ndarray\n",
    "        Passage numbers of each experiment.\n",
    "    exp : str\n",
    "        'both' keeps validation passages from both experiments; otherwise only\n",
    "        validation passages whose experiment name equals `exp` are kept.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (mse_stored_intercept_only, mse_stored, y_hat_folds, y_test_folds,\n",
    "        test_fold_size) after all folds have been appended.\n",
    "    \"\"\"\n",
    "\n",
    "    # select which experiment we'll use for testing for this round of k-fold\n",
    "    for test_experiment in experiments:\n",
    "\n",
    "        print(f\"Test experiment: {test_experiment}\")\n",
    "\n",
    "        # passages of the held-out experiment and of the other experiment\n",
    "        if test_experiment == '384':\n",
    "            n_passages, own_name = n_passages_384, '384'\n",
    "            other_passages, other_name = n_passages_243, '243'\n",
    "        else:\n",
    "            n_passages, own_name = n_passages_243, '243'\n",
    "            other_passages, other_name = n_passages_384, '384'\n",
    "\n",
    "        # select which passage we'll use for testing\n",
    "        for test_passage_number in n_passages:\n",
    "\n",
    "            test_indices = split_by_exp_passage_num(test_experiment, test_passage_number, data_labels)\n",
    "            train_indices = np.setdiff1d(np.arange(data_labels.shape[0]), test_indices)\n",
    "            train_labels = data_labels[train_indices]\n",
    "\n",
    "            test_indices_arr.extend(test_indices)\n",
    "\n",
    "            X_train = X[train_indices]\n",
    "            y_train = y[train_indices]\n",
    "            X_test = X[test_indices]\n",
    "            y_test = y[test_indices]\n",
    "\n",
    "            # Validation candidates: every other passage of the test experiment\n",
    "            # plus all passages of the other experiment.\n",
    "            # NOTE(review): np.delete removes by position, so this assumes a\n",
    "            # passage number equals its index in n_passages - confirm.\n",
    "            val_passages = np.concatenate((np.delete(n_passages, test_passage_number), other_passages))\n",
    "            val_exp_names = np.concatenate((np.repeat(own_name, n_passages.shape[0]-1),\n",
    "                                            np.repeat(other_name, other_passages.shape[0])))\n",
    "\n",
    "            if exp != 'both':\n",
    "                # flatnonzero keeps the result 1-D even when only one passage\n",
    "                # matches (argwhere + squeeze collapsed that case to 0-d).\n",
    "                idxs_exp = np.flatnonzero(val_exp_names == exp)\n",
    "                val_passages = val_passages[idxs_exp]\n",
    "                val_exp_names = val_exp_names[idxs_exp]\n",
    "\n",
    "            mse_test, mse_intercept, val_perf, y_pred, _y_pred_intercept = run_himalayas(\n",
    "                X_train, y_train, X_test, y_test, alphas,\n",
    "                train_labels, feature_grouper, n_iter, use_kernelized,\n",
    "                dataset, features_list, val_passages, val_exp_names)\n",
    "\n",
    "            val_stored.append(val_perf)\n",
    "            mse_stored_intercept_only.append(mse_intercept)\n",
    "            mse_stored.append(mse_test)\n",
    "            y_hat_folds.append(y_pred)\n",
    "            y_test_folds.append(y_test)\n",
    "            test_fold_size.append(X_test.shape[0])\n",
    "\n",
    "    # Return only after every experiment/passage fold has been evaluated.\n",
    "    return mse_stored_intercept_only, mse_stored, y_hat_folds, y_test_folds, test_fold_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "   \n",
    "        # Cross-validation fold construction, dispatched on `dataset`.\n",
    "        # NOTE(review): this cell is an indented fragment with no enclosing\n",
    "        # def/loop header, and it reads names it does not define (X, y,\n",
    "        # data_labels, alphas, feature_grouper, n_iter, use_kernelized,\n",
    "        # features_list and the *_stored / *_folds lists). It cannot run on a\n",
    "        # fresh kernel as written.\n",
    "        if dataset == 'pereira':\n",
    "            # NOTE(review): this branch has no body, which is a syntax error.\n",
    "            # Presumably it should call construct_splits_pereira(...) - confirm.\n",
    "           \n",
    "        \n",
    "        elif dataset == 'federonko':\n",
    "            \n",
    "            # Total sample count is sentence_num*sentence_length = 416, which\n",
    "            # splits evenly into 13 consecutive folds of split_size = 32.\n",
    "            sentence_length = 8\n",
    "            sentence_num = 52\n",
    "            split_size = 32\n",
    "            \n",
    "            for i in range(0, sentence_num*sentence_length, split_size):\n",
    "                \n",
    "                # hold out one contiguous block of indices; train on the rest\n",
    "                test_indices = np.arange(i, i+split_size)\n",
    "                train_indices = np.setdiff1d(np.arange(sentence_num*sentence_length), test_indices)\n",
    "                test_labels = data_labels[test_indices]\n",
    "                train_labels = data_labels[train_indices]\n",
    "\n",
    "                X_train = X[train_indices]\n",
    "                y_train = y[train_indices]\n",
    "                X_test = X[test_indices]\n",
    "                y_test = y[test_indices]\n",
    "                \n",
    "                # no validation passages are supplied for this dataset\n",
    "                mse_test, mse_intercept, val_perf, y_pred, y_pred_intercept = run_himalayas(X_train, \n",
    "                                                y_train, X_test, y_test, alphas, \n",
    "                                                train_labels, feature_grouper, n_iter, use_kernelized, \n",
    "                                                dataset, features_list, val_passages=None, val_exp_names=None)\n",
    "            \n",
    "                # record this fold's results\n",
    "                val_stored.append(val_perf)\n",
    "                mse_stored_intercept_only.append(mse_intercept)\n",
    "                mse_stored.append(mse_test)\n",
    "                y_hat_folds.append(y_pred)\n",
    "                y_test_folds.append(y_test) \n",
    "                test_fold_size.append(X_test.shape[0])   \n",
    "            \n",
    "        elif dataset == 'blank':\n",
    "            \n",
    "            num_samples = data_labels.shape[0]\n",
    "            \n",
    "            # one fold per unique value in data_labels: that value's samples\n",
    "            # are the test set and every other sample is training data\n",
    "            for test_story in np.unique(data_labels):\n",
    "                \n",
    "                # np.argwhere returns one row per match, so the indexed arrays\n",
    "                # gain an extra axis that np.squeeze removes below.\n",
    "                # NOTE(review): squeeze drops every length-1 axis, so a story\n",
    "                # with a single sample would also lose its sample axis - confirm\n",
    "                # this cannot happen.\n",
    "                test_indices = np.argwhere(data_labels==test_story)\n",
    "                train_indices = np.setdiff1d(np.arange(num_samples), test_indices)\n",
    "                test_labels = np.squeeze(data_labels[test_indices])\n",
    "                train_labels = np.squeeze(data_labels[train_indices])\n",
    "                \n",
    "                X_train = np.squeeze(X[train_indices])\n",
    "                y_train = np.squeeze(y[train_indices])\n",
    "                X_test = np.squeeze(X[test_indices])\n",
    "                y_test = np.squeeze(y[test_indices])\n",
    "                \n",
    "                # NOTE(review): four values are unpacked here, while the\n",
    "                # 'federonko' branch unpacks five from the same function -\n",
    "                # confirm how many values run_himalayas returns.\n",
    "                mse_test, mse_intercept, val_perf, y_pred = run_himalayas(X_train, \n",
    "                                            y_train, X_test, y_test, alphas, \n",
    "                                            train_labels, feature_grouper, n_iter, use_kernelized, \n",
    "                                            dataset, features_list, val_passages=None, val_exp_names=None)\n",
    "        \n",
    "                \n",
    "                # record this fold's results\n",
    "                val_stored.append(val_perf)\n",
    "                \n",
    "                mse_stored_intercept_only.append(mse_intercept)\n",
    "                mse_stored.append(mse_test)\n",
    "                y_hat_folds.append(y_pred)\n",
    "                y_test_folds.append(y_test)\n",
    "                test_fold_size.append(X_test.shape[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
