{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import fnmatch, os, sys\n",
    "sys.path.append('D:\\\\OneDrive\\\\codes\\\\xds')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading data from an xds format file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here are the files under this folder: \n",
      "\n",
      "['20190815_Greyson_Key_001.mat' '20190815_Greyson_Key_002.mat'\n",
      " '20190911_Greyson_Key_003.mat' '20190911_Greyson_Key_004.mat'\n",
      " '20190911_Greyson_PG_001.mat' '20190911_Greyson_PG_002.mat']\n"
     ]
    }
   ],
   "source": [
    "from xds import lab_data, list_to_nparray, smooth_binned_spikes\n",
    "\n",
    "base_path = \"D:/OneDrive/data/lab_data/Greyson_grasping/\"\n",
    "file_list = fnmatch.filter(os.listdir(base_path), \"*.mat\")\n",
    "file_list = np.sort(file_list)\n",
    "print(\"Here are the files under this folder: \\n\")\n",
    "print(file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D:/OneDrive/data/lab_data/Greyson_grasping/20190815_Greyson_Key_001.mat\n",
      "Monkey: Greyson\n",
      "Task: multi_gadget\n",
      "Collected on 2019/8/15 20:9:56.486 \n",
      "Raw file name is 20190815_Greyson_Key_001\n",
      "The array is in M1\n",
      "There are 96 neural channels\n",
      "Sorted? 0\n",
      "There are 12 EMG channels\n",
      "Current bin width is 0.0010 seconds\n",
      "The name of each EMG channel:\n",
      "EMG_FCR\n",
      "EMG_FDS1\n",
      "EMG_FDP3\n",
      "EMG_PT\n",
      "EMG_APB\n",
      "EMG_FPB\n",
      "EMG_LUM\n",
      "EMG_1DI\n",
      "EMG_EPL\n",
      "EMG_SUP\n",
      "EMG_ECU\n",
      "EMG_EDC1\n",
      "The dataset lasts 900.0090 seconds\n",
      "There are 126 trials\n",
      "In 113 trials the monkey got reward\n",
      "The new bin width is 0.0500 s\n"
     ]
    }
   ],
   "source": [
    "file_number = 0\n",
    "bin_size = 0.05\n",
    "dataset = lab_data(base_path, file_list[file_number])\n",
    "dataset.update_bin_data(bin_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dividing the data into training and testing sets and formatting it for decoder training\n",
    "The samples between trials are kept (they are not removed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 8992 training samples. \n",
      "There are 8992 testing samples. \n"
     ]
    }
   ],
   "source": [
    "from wiener_filter import format_data\n",
    "n_lags = 8\n",
    "train_test_ratio = 0.5 # Fraction of samples used for training: 0.5 means the first 50% of the samples are used for training and the remaining 50% for testing\n",
    "total_sample = np.size(dataset.spike_counts, 0)\n",
    "train_x = dataset.spike_counts[:int(train_test_ratio*total_sample), :] \n",
    "train_y = dataset.EMG[:int(train_test_ratio*total_sample), :]\n",
    "test_x = dataset.spike_counts[int(train_test_ratio*total_sample): , :] \n",
    "test_y = dataset.EMG[int(train_test_ratio*total_sample): , :]\n",
    "train_x, train_y = format_data(train_x, train_y, n_lags)\n",
    "test_x, test_y = format_data(test_x, test_y, n_lags)\n",
    "print(\"There are %d training samples. \" % np.size(train_x, 0))\n",
    "print(\"There are %d testing samples. \" % np.size(test_x, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a linear Wiener filter based decoder without L2 regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAF for linear Wiener filter is 0.760\n"
     ]
    }
   ],
   "source": [
    "from wiener_filter import train_wiener_filter, test_wiener_filter, vaf\n",
    "H_reg = train_wiener_filter(train_x, train_y, l2 = 0)\n",
    "test_y_pred = test_wiener_filter(test_x, H_reg)\n",
    "print(\"VAF for linear Wiener filter is %.3f\" % vaf(test_y, test_y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a linear Wiener filter based decoder with L2 regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sweeping ridge regularization using CV decoding on train data\n",
      "Testing c= 10.0\n",
      "Testing c= 16.237767391887218\n",
      "Testing c= 26.366508987303583\n",
      "Testing c= 42.81332398719393\n",
      "Testing c= 69.51927961775606\n",
      "Testing c= 112.88378916846884\n",
      "Testing c= 183.29807108324357\n",
      "Testing c= 297.63514416313194\n",
      "Testing c= 483.2930238571752\n",
      "Testing c= 784.7599703514607\n",
      "Testing c= 1274.2749857031336\n",
      "Testing c= 2069.13808111479\n",
      "Testing c= 3359.818286283781\n",
      "Testing c= 5455.594781168515\n",
      "Testing c= 8858.667904100823\n",
      "Testing c= 14384.498882876629\n",
      "Testing c= 23357.21469090121\n",
      "Testing c= 37926.90190732246\n",
      "Testing c= 61584.82110660255\n",
      "Testing c= 100000.0\n",
      "VAF for linear Wiener filter is 0.797\n"
     ]
    }
   ],
   "source": [
    "H_reg = train_wiener_filter(train_x, train_y, l2 = 1)\n",
    "test_y_pred = test_wiener_filter(test_x, H_reg)\n",
    "print(\"VAF for linear Wiener filter is %.3f\" % vaf(test_y, test_y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a nonlinear Wiener filter based decoder without L2 regularization\n",
    "The nonlinear version sometimes performs better than the linear one, but not always."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAF for nonlinear Wiener filter is 0.768\n"
     ]
    }
   ],
   "source": [
    "from wiener_filter import train_nonlinear_wiener_filter, test_nonlinear_wiener_filter\n",
    "H_reg, res_lsq = train_nonlinear_wiener_filter(train_x, train_y, l2 = 0)\n",
    "test_y_pred = test_nonlinear_wiener_filter(test_x, H_reg, res_lsq)\n",
    "print(\"VAF for nonlinear Wiener filter is %.3f\" % vaf(test_y, test_y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a nonlinear Wiener filter based decoder with L2 regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sweeping ridge regularization using CV decoding on train data\n",
      "Testing c= 10.0\n",
      "Testing c= 16.237767391887218\n",
      "Testing c= 26.366508987303583\n",
      "Testing c= 42.81332398719393\n",
      "Testing c= 69.51927961775606\n",
      "Testing c= 112.88378916846884\n",
      "Testing c= 183.29807108324357\n",
      "Testing c= 297.63514416313194\n",
      "Testing c= 483.2930238571752\n",
      "Testing c= 784.7599703514607\n",
      "Testing c= 1274.2749857031336\n",
      "Testing c= 2069.13808111479\n",
      "Testing c= 3359.818286283781\n",
      "Testing c= 5455.594781168515\n",
      "Testing c= 8858.667904100823\n",
      "Testing c= 14384.498882876629\n",
      "Testing c= 23357.21469090121\n",
      "Testing c= 37926.90190732246\n",
      "Testing c= 61584.82110660255\n",
      "Testing c= 100000.0\n",
      "VAF for nonlinear Wiener filter is 0.793\n"
     ]
    }
   ],
   "source": [
    "from wiener_filter import train_nonlinear_wiener_filter, test_nonlinear_wiener_filter\n",
    "H_reg, res_lsq = train_nonlinear_wiener_filter(train_x, train_y, l2 = 1)\n",
    "test_y_pred = test_nonlinear_wiener_filter(test_x, H_reg, res_lsq)\n",
    "print(\"VAF for nonlinear Wiener filter is %.3f\" % vaf(test_y, test_y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
