{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "54a0e7ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "#why is cross val score different than test val split?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8a898b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "04681df3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import initialize, initialize_config_module, initialize_config_dir, compose\n",
    "import tasks\n",
    "import numpy as np\n",
    "from data.speech_nonspeech_subject_data import NonLinguisticSubjectData, SentenceOnsetSubjectData\n",
    "from sklearn import linear_model\n",
    "from sklearn.model_selection import cross_val_score, cross_validate\n",
    "import scipy\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn import preprocessing\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3bc02801",
   "metadata": {},
   "outputs": [],
   "source": [
    "# +exp=timed_tests ++exp.runner.num_workers=0 +data=decoding_base +model=classification_regression_model ++exp.runner.name=\"regression_runner_test_set\" +task=timed_tests +criterion=empty_criterion ++data.interval_duration=3.0 ++data.name=\"word_onset_regression\" +test=held_out_subjects ++test.test_split_path=/storage/czw/seeg_decoding/data/all_trials.json ++test.test_electrodes_path=\"None\" ++data.movie_transcripts_dir=\"/storage/czw/seeg_decoding/updated_word_features\" ++task.windows_start=-1 ++task.windows_end=1 ++task.window_duration=0.250 ++task.window_step=0.1 ++test.out_dir=\"/storage/czw/seeg_decoding/outputs/reg_test_set_all_trials_debug/\" ++data.saved_data_split=/storage/czw/seeg_decoding/saved_data_splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "33e92919",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'exp': {'runner': {'name': 'regression_runner_test_set', 'num_workers': 0}}, 'data': {'name': 'word_onset_regression', 'high_gamma': False, 'samp_frequency': 2048, 'raw_brain_data_dir': '/storage/datasets/neuroscience/ecog', 'subject': 'sub3', 'brain_runs': ['trial000', 'trial001'], 'electrodes': ['T1b2'], 'rereference': 'None', 'normalization': False, 'despike': False, 'delta': '???', 'duration': '???', 'words': [], 'cached_transcript_aligns': '/storage/czw/LanguageEcog/semantics/saved_aligns', 'test_split': 0.2, 'preprocessor': 'stft', 'interval_duration': 3.0, 'movie_transcripts_dir': '/storage/czw/seeg_decoding/updated_word_features', 'saved_data_split': '/storage/czw/seeg_decoding/saved_data_splits'}, 'model': {'name': 'classification_regression_model'}, 'task': {'name': 'classification_regression_task', 'windows_start': -1, 'windows_end': 1, 'window_duration': 0.25, 'num_repeat': 10, 'window_step': 0.1}, 'criterion': {'name': 'empty_criterion'}, 'test': {'test_split_path': '/storage/czw/seeg_decoding/data/all_trials.json', 'test_electrodes_path': 'None', 'out_dir': '/storage/czw/seeg_decoding/outputs/reg_test_set_all_trials_debug/'}}\n"
     ]
    }
   ],
   "source": [
    "# duration=1.5\n",
    "# delta=-0.5\n",
    "data_delta = -0.1\n",
    "data_duration = 0.7\n",
    "\n",
    "with initialize(version_base=None, config_path=\"../conf\"):\n",
    "    cfg = compose(overrides=['+exp=timed_tests',\n",
    "                             '++exp.runner.num_workers=0',\n",
    "                             '+data=decoding_base',\n",
    "                             '+model=classification_regression_model',\n",
    "                             '++exp.runner.name=regression_runner_test_set',\n",
    "                             '+task=timed_tests',\n",
    "                             '+criterion=empty_criterion',\n",
    "                             '++data.interval_duration=3.0',\n",
    "                             '++data.name=word_onset_regression',\n",
    "                             '+test=held_out_subjects', \n",
    "                             '++test.test_split_path=/storage/czw/seeg_decoding/data/all_trials.json',\n",
    "                             '++test.test_electrodes_path=None',\n",
    "                             '++data.movie_transcripts_dir=/storage/czw/seeg_decoding/updated_word_features',\n",
    "                             '++task.windows_start=-1',\n",
    "                             '++task.windows_end=1',\n",
    "                             '++task.window_duration=0.250',\n",
    "                             '++task.window_step=0.1',\n",
    "                             '++test.out_dir=/storage/czw/seeg_decoding/outputs/reg_test_set_all_trials_debug/',\n",
    "                             '++data.saved_data_split=/storage/czw/seeg_decoding/saved_data_splits'])\n",
    "    print(cfg)\n",
    "#    ++data.name=\"sentence_position_finetuning\"\n",
    "data_cfg = cfg.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c942b6e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_cfg[\"subject\"] = \"sub3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "00ea391f",
   "metadata": {},
   "outputs": [],
   "source": [
    "windows_start = cfg.task.windows_start\n",
    "windows_end = cfg.task.windows_end\n",
    "all_windows_duration = windows_end - windows_start\n",
    "window_duration = cfg.task.window_duration\n",
    "window_step = cfg.task.window_step\n",
    "assert windows_end > windows_start\n",
    "assert (windows_end - windows_start) > window_duration\n",
    "window_starts = np.arange(windows_start, windows_end, window_step)\n",
    "sr = 2048\n",
    "idx_starts = [int((w - windows_start)*sr) for w in window_starts]\n",
    "idx_duration = int(window_duration*2048)\n",
    "\n",
    "subject_test_results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ac027869",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_starts = [int((w - windows_start)*sr) for w in window_starts]\n",
    "idx_duration = int(window_duration*2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "59a73c0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "e = \"T1cIe11\"\n",
    "brain_runs = [\"trial000\",\"trial001\", \"trial002\"]\n",
    "# data_cfg\n",
    "\n",
    "data_cfg_copy = data_cfg.copy()\n",
    "data_cfg_copy.duration = all_windows_duration\n",
    "data_cfg_copy.delta = windows_start\n",
    "data_cfg_copy.electrodes = [e]\n",
    "data_cfg_copy.brain_runs = brain_runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7959e480",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.data = data_cfg_copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1a8d00e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cfg.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f26cbc0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = tasks.setup_task(cfg.task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4724a9c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "s = SentenceOnsetSubjectData(data_cfg_copy)\n",
    "cached_word_df = s.labels\n",
    "all_data = s.neural_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "50e421f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'word_onset_regression', 'high_gamma': False, 'samp_frequency': 2048, 'raw_brain_data_dir': '/storage/datasets/neuroscience/ecog', 'subject': 'sub3', 'brain_runs': ['trial000', 'trial001', 'trial002'], 'electrodes': ['T1cIe11'], 'rereference': 'None', 'normalization': False, 'despike': False, 'delta': -1, 'duration': 2, 'words': [], 'cached_transcript_aligns': '/storage/czw/LanguageEcog/semantics/saved_aligns', 'test_split': 0.2, 'preprocessor': 'stft', 'interval_duration': 3.0, 'movie_transcripts_dir': '/storage/czw/seeg_decoding/updated_word_features', 'saved_data_split': '/storage/czw/seeg_decoding/saved_data_splits'}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_cfg_copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9b1efa74",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = idx_starts[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ce092fdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "cached_seeg_data = all_data[:,:,start:start+idx_duration]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1dd83e69",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7196\n",
      "7196\n"
     ]
    }
   ],
   "source": [
    "task.load_datasets(cfg.data, cached_seeg_data=cached_seeg_data, cached_word_df=cached_word_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a1aa336c",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train = task.train_X, task.train_y\n",
    "X_test, y_test = task.test_X, task.test_y\n",
    "# X_val, y_val = task.val_X, task.val_y\n",
    "\n",
    "X_train = scipy.signal.decimate(X_train, 10, axis=-1)\n",
    "X_test = scipy.signal.decimate(X_test, 10, axis=-1)\n",
    "#X_val = scipy.signal.decimate(X_val, 10, axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "69f49dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = linear_model.LogisticRegression(random_state=0, max_iter=500)\n",
    "#scaler = preprocessing.StandardScaler().fit(X_train)\n",
    "#X_train = scaler.transform(X_train)\n",
    "#model.fit(X_train, y_train)\n",
    "\n",
    "clf = make_pipeline(preprocessing.StandardScaler(), model)\n",
    "#self.task.cfg.num_repeat\n",
    "#cv_score = cross_val_score(clf, X_train, y_train, cv=5, scoring=\"roc_auc\")\n",
    "X = task.dataset.seeg_data\n",
    "X = scipy.signal.decimate(X, 10, axis=-1)\n",
    "y = np.array(task.dataset.label)\n",
    "\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "cv = StratifiedShuffleSplit(5, random_state=0)\n",
    "cv_score = cross_val_score(clf, X, y, cv=cv, scoring=\"roc_auc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "1fde4730",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.77162037, 0.74760031, 0.76563272, 0.77395062, 0.76824846])"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cv_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "b434c2b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(7196, 52)\n"
     ]
    }
   ],
   "source": [
    "print(X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "e40887a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "7b49472f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7716203703703703\n",
      "0.7476003086419754\n",
      "0.7656327160493829\n",
      "0.7739506172839505\n",
      "0.7682484567901235\n"
     ]
    }
   ],
   "source": [
    "for train_idxs, test_idxs in cv.split(X,y):\n",
    "#     print(train_idxs[:10])\n",
    "    model = linear_model.LogisticRegression(random_state=0, max_iter=500)\n",
    "    model = make_pipeline(preprocessing.StandardScaler(), model)\n",
    "\n",
    "    X_train = X[train_idxs]\n",
    "    y_train = y[train_idxs]\n",
    "    model.fit(X_train, y_train)\n",
    "\n",
    "    X_test = X[test_idxs]\n",
    "    X_test = scaler.transform(X_test)\n",
    "    y_test = y[test_idxs]\n",
    "#     print(model.score(X_test, y_test, scoring=\"roc_auc\"))\n",
    "    print(roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))\n",
    "#     break\n",
    "#get roc_auc of model TODO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "53dd93ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/storage/czw/anaconda3/envs/sss/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:818: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG,\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.6998332406892718"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = linear_model.LogisticRegression(random_state=0, max_iter=500)\n",
    "scaler = preprocessing.StandardScaler().fit(X)\n",
    "model.fit(X, y)\n",
    "model.score(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "ff91d0c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6476, 52)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "51b1c0fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(max_iter=500, random_state=0)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = linear_model.LogisticRegression(random_state=0, max_iter=500)\n",
    "scaler = preprocessing.StandardScaler().fit(X_train)\n",
    "X_train = scaler.transform(X_train)\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "01d00b71",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = scaler.transform(X_test)\n",
    "test_score = model.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fa6048c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6736111111111112"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "fca4ca85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7196, 52)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "712bf9c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'word_onset_regression', 'high_gamma': False, 'samp_frequency': 2048, 'raw_brain_data_dir': '/storage/datasets/neuroscience/ecog', 'subject': 'sub3', 'brain_runs': ['trial000', 'trial001', 'trial002'], 'electrodes': ['T1cIe11'], 'rereference': 'None', 'normalization': False, 'despike': False, 'delta': -1, 'duration': 2, 'words': [], 'cached_transcript_aligns': '/storage/czw/LanguageEcog/semantics/saved_aligns', 'test_split': 0.2, 'preprocessor': 'stft', 'interval_duration': 3.0, 'movie_transcripts_dir': '/storage/czw/seeg_decoding/updated_word_features', 'saved_data_split': '/storage/czw/seeg_decoding/saved_data_splits'}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task.dataset.cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2636d27d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
