{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c264d7e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_data = 'data/dorschky2024'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7f08508-3893-4c6b-80bd-348250ac0762",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "786292c0-3cb2-4201-bdfb-8158299207a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import subprocess\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "os.chdir('..')\n",
    "\n",
    "from matplotlib.animation import FuncAnimation\n",
    "from IPython.display import HTML\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c91bca3",
   "metadata": {},
   "source": [
    "# Functions to segment bouts from Dorschky's Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "593bc980-24c1-4bcb-87a8-1ddc809c34c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a moving avg function for a 1D-array:\n",
    "def moving_avg(arr, window_size):\n",
    "    return np.convolve(arr, np.hamming(window_size)/window_size, mode='valid')\n",
    "  \n",
    "# and moving std: \n",
    "def moving_std(arr, window_size):\n",
    "    return np.sqrt(moving_avg(arr**2, window_size) - moving_avg(arr, window_size)**2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1324ab25",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_if_inbetween_triggers(candidates_standing, candidates_turning, df):\n",
    "    standing_ = []\n",
    "    turning_ = []\n",
    "    for i in range(7,13): # All trial triggers\n",
    "        # get the trigger times\n",
    "        start_ = df1[df1.TRIGGER == i].index[0] / 1000\n",
    "        end_ = df1[df1.TRIGGER == i].index[-1] / 1000\n",
    "        # append the standing and turning candidates that are inbetween the triggers and append them to the lists as a flat list\n",
    "        standing_+= [*candidates_standing[np.logical_and(candidates_standing > start_ - 5, candidates_standing < end_)]]\n",
    "        turning_+= [*candidates_turning[np.logical_and(candidates_turning > start_, candidates_turning < end_ + 10)]]\n",
    "\n",
    "    return np.ravel(standing_), np.ravel(turning_)\n",
    "\n",
    "# write a function to select standing candidates where the next candidate (chronologically) is a turning candidate\n",
    "def find_consecutive_candidates(candidates_standing, candidates_turning):\n",
    "    standing_candidates = []\n",
    "    turning_candidates = []\n",
    "    for i in range(len(candidates_standing)-1):\n",
    "        # check if there is a turning candidate between the current standing candidate and the next standing candidate, and append both to the lists\n",
    "        if np.any(np.logical_and(candidates_turning > candidates_standing[i], candidates_turning < candidates_standing[i+1])):\n",
    "            standing_candidates.append(candidates_standing[i])\n",
    "            turning_candidates.append(candidates_turning[np.logical_and(candidates_turning > candidates_standing[i], candidates_turning < candidates_standing[i+1])][0])\n",
    "    return np.array(standing_candidates), np.array(turning_candidates)\n",
    "\n",
    "def filter_by_length(candidates_standing, candidates_turning):\n",
    "    avg_length = np.mean(candidates_turning - candidates_standing)\n",
    "    std_length = np.std(candidates_turning - candidates_standing)\n",
    "    print(f\"Lower bound: {avg_length - 1.5*std_length}, Upper bound: {avg_length + 1.5*std_length}\")\n",
    "    # filter out the candidates where the length is not within 2 std of the mean\n",
    "    mask = np.logical_and(candidates_turning - candidates_standing > avg_length - 1.5*std_length, candidates_turning - candidates_standing < avg_length + 1.5*std_length)\n",
    "    return candidates_standing[mask], candidates_turning[mask]\n",
    "\n",
    "def filter_running_short_moves(candidates_standing, candidates_turning, m_std):\n",
    "    # For walking & S01, all bouts have been detected correctly by the algorithm.\n",
    "    # For running, the algorithm has detected some bouts that are too short and are the walk-back route.\n",
    "    # To filter these we look if the following condition is met:\n",
    "    # The bout is at a trigger point >= 10 (minus 5 seconds)\n",
    "    # The bout does not contain a moving std > 2.1 (Which gait doesn't usually have)\n",
    "\n",
    "    # get the trigger times\n",
    "    trigger_start_roi = df1[df1.TRIGGER == 10].index[0] / 1000 # We don't need to take the -5 seconds into account here, as the pre-trigger is guaranteed to be included in the candidates\n",
    "    mask = candidates_standing <= trigger_start_roi # So that all candidates are before the trigger are included automatically\n",
    "\n",
    "    # Get the first item in the mask that is False \n",
    "    start = np.where(mask == False)[0][0]\n",
    "    for i in range(start, len(candidates_standing)):\n",
    "        # Check if the moving std is > 2\n",
    "        i_0 = int(candidates_standing[i] * 1000)\n",
    "        i_1 = int(candidates_turning[i] * 1000)\n",
    "        if np.any(m_std[i_0:i_1] > 2.1):\n",
    "            mask[i] = True\n",
    "\n",
    "    return candidates_standing[mask], candidates_turning[mask]\n",
    "\n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0324c4fc-54fc-4334-8a64-319f61abdce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_candidates(subject, plot = True):\n",
    "    if plot:\n",
    "        fig = plt.figure(figsize = (30,10))\n",
    "    #plt.plot(df1.FEMUR_L_ACC_Y)\n",
    "\n",
    "    t_s = np.arange(0,len(df1))/1000\n",
    "    if plot:\n",
    "        plt.plot(t_s,df1.FOOT_R_GYRO_Z)\n",
    "        plt.plot(t_s,df1.TRIGGER/2)\n",
    "\n",
    "    window_size_h_a = 256\n",
    "    m_a = moving_avg(df1.PELVIS_GYRO_Y,window_size_h_a*2)\n",
    "    if plot:\n",
    "        plt.plot(t_s[window_size_h_a:-window_size_h_a+1],m_a)\n",
    "    window_size_h = 192\n",
    "    m_std = moving_std(df1.FOOT_L_GYRO_Z,window_size_h*2)\n",
    "    if plot:\n",
    "        plt.plot(t_s[window_size_h:-window_size_h+1],m_std)\n",
    "\n",
    "    #plt.xlim(600,800)\n",
    "    if plot:\n",
    "        plt.ylim(-15,15)\n",
    "        plt.grid()\n",
    "        for i in range(len(df[df.TRIGGER > 0])-1):\n",
    "            end =  (df.iloc[(df[df.TIME==0].index-1)]).sort_index().TIME.iat[i]\n",
    "            start = (df.iloc[(df[df.TIME==0].index)]).sort_index().TIME.iat[i]\n",
    "            ttime =   df[df.TRIGGER > 6].TIME.iat[i]\n",
    "            trigger = df1[df1.TRIGGER > 6].index[i] / 1000\n",
    "            plt.fill_between([trigger+start-ttime, trigger+end-ttime],[0,0],[15,15],color='k',alpha = 0.2)\n",
    "\n",
    "    standing_threshold = 0.03\n",
    "    turning_threshold = -1\n",
    "    if subject in [7,10]:\n",
    "        turning_threshold = -0.8 # That person turned rather slowly\n",
    "    candidates_standing = t_s[window_size_h:-window_size_h+1][moving_std(df1.FOOT_L_GYRO_Z,window_size_h*2)<standing_threshold]\n",
    "    candidates_turning = t_s[window_size_h_a:-window_size_h_a+1][moving_avg(df1.PELVIS_GYRO_Y,window_size_h_a*2)<turning_threshold]\n",
    "    # Also check turning in the other direction\n",
    "    candidates_turning = np.append(candidates_turning, t_s[window_size_h_a:-window_size_h_a+1][moving_avg(df1.PELVIS_GYRO_Y,window_size_h_a*2)>-turning_threshold])\n",
    "    candidates_standing, candidates_turning = check_if_inbetween_triggers(candidates_standing, candidates_turning, df1)\n",
    "    candidates_standing, candidates_turning = find_consecutive_candidates(candidates_standing, candidates_turning)\n",
    "    candidates_standing, candidates_turning = filter_by_length(candidates_standing, candidates_turning)\n",
    "    candidates_standing, candidates_turning = filter_running_short_moves(candidates_standing, candidates_turning, m_std)\n",
    "    print(f\"Found {len(candidates_standing)} bouts with an average length of {np.mean(candidates_turning-candidates_standing-2)} seconds\")\n",
    "    if plot:\n",
    "        plt.plot(candidates_standing,np.ones_like(candidates_standing),'bx')\n",
    "        plt.plot(candidates_turning,np.ones_like(candidates_turning)*1.1,'rx')\n",
    "        for i in range(len(candidates_standing)):\n",
    "            plt.fill_between([candidates_standing[i]+1,candidates_turning[i]-1],[0,0],[15,15],color='r',alpha = 0.2)\n",
    "        plt.title(f\"Subject {subject}\")\n",
    "    return candidates_standing, candidates_turning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c36da4c",
   "metadata": {},
   "source": [
    "## Get the speed of a foot for a trial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef172dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gaitmap.trajectory_reconstruction import RegionLevelTrajectory, RtsKalman\n",
    "def get_speed_foot_x(imu_dataframe):\n",
    "    trajectory_method = RtsKalman()\n",
    "    # We set the ori and pos methods explicitly to `None` here to silence a warning indicating potential user error.\n",
    "    # In general, when `trajectory_method` is provided `ori_method` and `pos_method` are ignored.\n",
    "    trajectory_full = RegionLevelTrajectory(trajectory_method=trajectory_method, ori_method=None, pos_method=None)\n",
    "\n",
    "    data_ = imu_dataframe[['FOOT_R_GYRO_X','FOOT_R_GYRO_Y','FOOT_R_GYRO_Z','FOOT_R_ACC_X','FOOT_R_ACC_Y','FOOT_R_ACC_Z']].copy()\n",
    "    # rename the columns to gyr_x, gyr_y, gyr_z, acc_x, acc_y, acc_z\n",
    "    data_.columns = ['gyr_x','gyr_y','gyr_z','acc_x','acc_y','acc_z']\n",
    "    #data_.acc_z = -data_.acc_z\n",
    "    data_.gyr_z = data_.gyr_z * 180 / np.pi\n",
    "    data_.gyr_y = data_.gyr_y * 180 / np.pi\n",
    "    data_.gyr_x = data_.gyr_x * 180 / np.pi\n",
    "\n",
    "    dummy_regions_list = pd.DataFrame([[0, len(data_)]], columns=[\"start\", \"end\"]).rename_axis(\"gs_id\")\n",
    "    sampling_frequency_hz = 1000\n",
    "    trajectory_full.estimate(data=data_.reset_index(), regions_of_interest=dummy_regions_list, sampling_rate_hz=sampling_frequency_hz)\n",
    "    first_region_position = trajectory_full.position_.loc[0]\n",
    "    speed_r = np.sqrt((first_region_position.pos_x.diff()*1000)**2 + (first_region_position.pos_y.diff()*1000)**2)\n",
    "    \n",
    "    # select the position of the first (and only) gait sequence\n",
    "    first_region_position = trajectory_full.position_.loc[0]\n",
    "\n",
    "    first_region_position.plot()\n",
    "    plt.title(\"Left Foot Trajectory per axis\")\n",
    "    plt.xlabel(\"sample\")\n",
    "    plt.ylabel(\"position [m]\")\n",
    "    plt.show()\n",
    "    \n",
    "    # select the orientation of the first (and only) gait sequence\n",
    "    first_region_orientation = trajectory_full.orientation_.loc[0]\n",
    "    \n",
    "    first_region_orientation.plot()\n",
    "    plt.title(\"Left Foot Orientation per axis\")\n",
    "    plt.xlabel(\"sample\")\n",
    "    plt.ylabel(\"orientation [a.u.]\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "    data_ = imu_dataframe[['FOOT_L_GYRO_X','FOOT_L_GYRO_Y','FOOT_L_GYRO_Z','FOOT_L_ACC_X','FOOT_L_ACC_Y','FOOT_L_ACC_Z']].copy()\n",
    "    # rename the columns to gyr_x, gyr_y, gyr_z, acc_x, acc_y, acc_z\n",
    "    data_.columns = ['gyr_x','gyr_y','gyr_z','acc_x','acc_y','acc_z']\n",
    "    #data_.acc_z = -data_.acc_z\n",
    "    data_.gyr_z = data_.gyr_z * 180 / np.pi\n",
    "    data_.gyr_y = data_.gyr_y * 180 / np.pi\n",
    "    data_.gyr_x = data_.gyr_x * 180 / np.pi\n",
    "\n",
    "    dummy_regions_list = pd.DataFrame([[0, len(data_)]], columns=[\"start\", \"end\"]).rename_axis(\"gs_id\")\n",
    "    sampling_frequency_hz = 1000\n",
    "    trajectory_full.estimate(data=data_.reset_index(), regions_of_interest=dummy_regions_list, sampling_rate_hz=sampling_frequency_hz)\n",
    "    first_region_position = trajectory_full.position_.loc[0]\n",
    "    speed_l = np.sqrt((first_region_position.pos_x.diff()*1000)**2 + (first_region_position.pos_y.diff()*1000)**2)\n",
    "\n",
    "    return speed_r.to_numpy(), speed_l.to_numpy()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2b98c76",
   "metadata": {},
   "source": [
    "# Build the necessary dataframes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59f3d5da",
   "metadata": {},
   "source": [
    "## IMU Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b7d655d-af81-4bc5-9316-4a074ec8ec7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.signal\n",
    "imu_dataframe = pd.DataFrame(columns = ['subject','trial','data','hastrigger','triggerno','triggeridx','speed_r','speed_l'])\n",
    "signal_keys = ['PELVIS','FEMUR_R','TIBIA_R','FOOT_R','FEMUR_L','TIBIA_L','FOOT_L']\n",
    "signal_keys = [f'{key}_{signal}' for key in signal_keys for signal in ['ACC_X','ACC_Y','GYRO_Z']]\n",
    "for i in range(3,4):\n",
    "    df1 = pd.read_parquet(f'{path_to_data}/P{str(i).zfill(2)}_IMU.parquet')\n",
    "    df = pd.read_parquet(f'{path_to_data}/P{str(i).zfill(2)}_OMC.parquet')\n",
    "    speed_r, speed_l = get_speed_foot_x(df1)\n",
    "    dfnew = pd.DataFrame()\n",
    "    starts_, ends_ = get_candidates(i, plot = True)\n",
    "    for signal in signal_keys:\n",
    "        dfnew[signal] = scipy.signal.decimate(df1[signal], 10)\n",
    "    speed_r = scipy.signal.decimate(speed_r[1:], 10)\n",
    "    speed_l = scipy.signal.decimate(speed_l[1:], 10)\n",
    "\n",
    "    for trial_idx, trial in enumerate(zip(starts_, ends_)):\n",
    "        start_idx = int(100*trial[0])\n",
    "        end_idx = int(100*trial[1])\n",
    "        trial_data = np.zeros((len(signal_keys), end_idx-start_idx))\n",
    "        for idx, signal in enumerate(signal_keys):\n",
    "            trial_data[idx] = dfnew[signal][start_idx:end_idx]\n",
    "        \n",
    "        # Count all triggers in the trial up to end_idx\n",
    "        trigger_hist = df1.iloc[:end_idx*10][df1.iloc[:end_idx*10].TRIGGER > 0]\n",
    "        hastrigger = df1.iloc[:end_idx*10][df1.iloc[:end_idx*10].TRIGGER > 0].TRIGGER.max()\n",
    "        trigger_no = len(trigger_hist[trigger_hist.TRIGGER == hastrigger])\n",
    "\n",
    "        # Find the trigger index relative to the trial start\n",
    "        trigger_idx = trigger_hist[trigger_hist.TRIGGER == hastrigger].index[-1] / 1000 - trial[0]\n",
    "\n",
    "        # get the speed of the feet\n",
    "        speed_r_ = speed_r[start_idx:end_idx]\n",
    "        speed_l_ = speed_l[start_idx:end_idx]\n",
    "\n",
    "        imu_dataframe.loc[len(imu_dataframe)] = [i, trial_idx, torch.from_numpy(trial_data), hastrigger, trigger_no, trigger_idx, torch.from_numpy(speed_r_.copy()), torch.from_numpy(speed_l_.copy())]\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f023f6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "## One more filter step: Remove all trials where duration is smaller than 4.5 seconds\n",
    "imu_dataframe = imu_dataframe[imu_dataframe.data.apply(lambda x: x.shape[1]) > 450]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed55d4e3",
   "metadata": {},
   "source": [
    "## Save "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61320165-e93a-45fe-8950-c18af4dc3de7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle the imu dataframe\n",
    "import pickle\n",
    "with open('data/dorschky2024/dorschky_sequences.pkl','wb') as output:\n",
    "    pickle.dump(imu_dataframe, output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pinn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
