{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm\n",
    "\n",
    "def cross_correlation(series1, series2, lag=0):\n",
    "    if lag < 0:\n",
    "        raise ValueError(\"Lag must be non-negative\")\n",
    "    # Align the series based on lag\n",
    "    if lag == 0:\n",
    "        s1 = series1\n",
    "        s2 = series2\n",
    "    else:\n",
    "        s1 = series1[lag:]\n",
    "        s2 = series2[:-lag]\n",
    "    # Drop missing values from both series\n",
    "    valid = s1.notna() & s2.notna()\n",
    "    s1 = s1[valid]\n",
    "    s2 = s2[valid]\n",
    "    \n",
    "    if len(s1) < 2 or len(s2) < 2:\n",
    "        return np.nan\n",
    "    \n",
    "    # Convert to numpy arrays and compute standard deviations\n",
    "    x = s1.to_numpy()\n",
    "    y = s2.to_numpy()\n",
    "    std_x = np.std(x, ddof=1)\n",
    "    std_y = np.std(y, ddof=1)\n",
    "    \n",
    "    if std_x == 0 or std_y == 0:\n",
    "        return np.nan\n",
    "    \n",
    "    # Compute covariance with Bessel's correction\n",
    "    cov = np.cov(x, y, ddof=1)[0, 1]\n",
    "    return cov / (std_x * std_y)\n",
    "\n",
    "def process_trajectory(traj, group, other_features, lags):\n",
    "    \"\"\"\n",
    "    Process a single patient trajectory. For each feature and each lag,\n",
    "    compute the cross-correlation with 'o:SOFA' after dropping missing values.\n",
    "    \"\"\"\n",
    "    local_results = {feature: {lag: [] for lag in lags} for feature in other_features}\n",
    "    group = group.sort_values('step')\n",
    "    sofa_series = group['o:SOFA']\n",
    "    \n",
    "    for feature in other_features:\n",
    "        feature_series = group[feature]\n",
    "        # Compute a valid mask once for this feature\n",
    "        valid = sofa_series.notna() & feature_series.notna()\n",
    "        sofa_valid = sofa_series[valid]\n",
    "        feature_valid = feature_series[valid]\n",
    "        \n",
    "        for lag in lags:\n",
    "            # Only compute if enough data points exist after alignment\n",
    "            if len(sofa_valid) > lag:\n",
    "                corr_value = cross_correlation(sofa_valid, feature_valid, lag)\n",
    "                if not np.isnan(corr_value):\n",
    "                    local_results[feature][lag].append(corr_value)\n",
    "    return local_results\n",
    "\n",
    "def merge_results(results, other_features, lags):\n",
    "    \"\"\"\n",
    "    Merge the results (lists of correlation values) from each trajectory.\n",
    "    \"\"\"\n",
    "    merged = {feature: {lag: [] for lag in lags} for feature in other_features}\n",
    "    for res in results:\n",
    "        for feature in other_features:\n",
    "            for lag in lags:\n",
    "                merged[feature][lag].extend(res[feature][lag])\n",
    "    return merged\n",
    "\n",
    "def main():\n",
    "    # Update the file path as necessary\n",
    "    csv_file = \"./sepsis_final_data_RAW_withTimes_continuous.csv\"\n",
    "    df = pd.read_csv(csv_file)\n",
    "    df = df.sort_values(['traj', 'step'])\n",
    "    \n",
    "    # Identify observation columns (those starting with \"o\")\n",
    "    obs_cols = [col for col in df.columns if col.startswith('o')]\n",
    "    if 'o:SOFA' not in obs_cols:\n",
    "        raise ValueError(\"Column 'o:SOFA' not found in dataset\")\n",
    "    \n",
    "    # List of features to compare with o:SOFA (excluding itself)\n",
    "    other_features = [col for col in obs_cols if col != 'o:SOFA']\n",
    "    \n",
    "    # Remove specific columns that are not required\n",
    "    for col in [\"o:max_dose_vaso\", \"o:input_4hourly\", \"o:age\", \"o:gender\", \"o:re_admission\"]:\n",
    "        if col in other_features:\n",
    "            other_features.remove(col)\n",
    "    \n",
    "    # Define the lags to compute\n",
    "    lags = [0, 1, 2, 3]\n",
    "    \n",
    "    # Group data by patient trajectory\n",
    "    grouped = list(df.groupby('traj'))\n",
    "    \n",
    "    # Process each trajectory in parallel\n",
    "    results = Parallel(n_jobs=-1)(\n",
    "        delayed(process_trajectory)(traj, group, other_features, lags)\n",
    "        for traj, group in tqdm(grouped, desc=\"Processing trajectories\")\n",
    "    )\n",
    "    \n",
    "    # Merge results from all trajectories\n",
    "    merged_results = merge_results(results, other_features, lags)\n",
    "    \n",
    "    # Average the correlations for each feature and lag across patients\n",
    "    avg_corr = {feature: {} for feature in other_features}\n",
    "    for feature in other_features:\n",
    "        for lag in lags:\n",
    "            if len(merged_results[feature][lag]) > 0:\n",
    "                avg_corr[feature][lag] = np.mean(merged_results[feature][lag])\n",
    "            else:\n",
    "                avg_corr[feature][lag] = np.nan\n",
    "    \n",
    "    # Convert the results into a DataFrame for easier viewing\n",
    "    avg_corr_df = pd.DataFrame(avg_corr).T\n",
    "    avg_corr_df = avg_corr_df[lags]  # Ensure columns are in lag order\n",
    "    print(avg_corr_df)\n",
    "    \n",
    "    return avg_corr_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing trajectories: 100%|██████████| 18913/18913 [03:20<00:00, 94.37it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            0         1         2         3\n",
      "o:mechvent           0.176696  0.264492  0.274567  0.257572\n",
      "o:Weight_kg         -0.005629 -0.023109 -0.026958 -0.021322\n",
      "o:GCS               -0.444028 -0.437736 -0.435731 -0.428769\n",
      "o:HR                 0.007500 -0.014189 -0.025463 -0.032236\n",
      "o:SysBP             -0.111810 -0.130255 -0.141662 -0.141074\n",
      "o:MeanBP            -0.138606 -0.155146 -0.167490 -0.171930\n",
      "o:DiaBP             -0.115162 -0.125050 -0.135155 -0.131484\n",
      "o:RR                -0.012927 -0.039910 -0.046301 -0.042172\n",
      "o:Temp_C             0.016114  0.009205  0.009203  0.012268\n",
      "o:FiO2_1             0.204245  0.216806  0.207921  0.200784\n",
      "o:Potassium          0.031623  0.021023  0.024423  0.020321\n",
      "o:Sodium            -0.034941 -0.031261 -0.029088 -0.015920\n",
      "o:Chloride          -0.060965 -0.016401 -0.001876 -0.004177\n",
      "o:Glucose            0.004725  0.022217  0.023451  0.020999\n",
      "o:Magnesium          0.003279 -0.010970 -0.012699 -0.005220\n",
      "o:Calcium            0.025881  0.007866 -0.005668 -0.007691\n",
      "o:Hb                 0.045729  0.035052  0.028563  0.018763\n",
      "o:WBC_count          0.030844  0.045673  0.049614  0.042391\n",
      "o:Platelets_count    0.003103 -0.002003 -0.004581 -0.009223\n",
      "o:PTT                0.023778  0.013674  0.010689  0.006544\n",
      "o:PT                 0.040016  0.041032  0.031050  0.023598\n",
      "o:Arterial_pH       -0.032900 -0.043484 -0.043374 -0.033732\n",
      "o:paO2              -0.389290 -0.380570 -0.393546 -0.403593\n",
      "o:paCO2              0.022805  0.026198  0.021412  0.021110\n",
      "o:Arterial_BE       -0.030143 -0.043799 -0.050597 -0.041792\n",
      "o:HCO3              -0.016435 -0.040391 -0.044894 -0.039489\n",
      "o:Arterial_lactate   0.057387  0.055688  0.049833  0.043093\n",
      "o:SIRS               0.026644  0.007781 -0.000514 -0.000390\n",
      "o:Shock_Index        0.097543  0.094589  0.094091  0.091613\n",
      "o:PaO2_FiO2         -0.429769 -0.431039 -0.439738 -0.447456\n",
      "o:cumulated_balance  0.041344  0.070664  0.089010  0.093522\n",
      "o:SpO2               0.066065  0.062148  0.052758  0.035425\n",
      "o:BUN                0.087242  0.070335  0.057005  0.043795\n",
      "o:Creatinine         0.118217  0.098378  0.080952  0.067665\n",
      "o:SGOT               0.112280  0.113244  0.109695  0.104732\n",
      "o:SGPT               0.085759  0.086083  0.085066  0.081814\n",
      "o:Total_bili         0.232405  0.225403  0.221775  0.223889\n",
      "o:INR                0.042939  0.042469  0.030767  0.023160\n",
      "o:input_total       -0.156176 -0.163395 -0.153063 -0.127062\n",
      "o:output_total      -0.160348 -0.180290 -0.180548 -0.162213\n",
      "o:output_4hourly    -0.314605 -0.226757 -0.177117 -0.164396\n"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    avg_corr_df = main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['o:mechvent', 'o:GCS', 'o:FiO2_1', 'o:paO2', 'o:PaO2_FiO2',\n",
      "       'o:Total_bili', 'o:output_4hourly'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "selected_indices = avg_corr_df[(avg_corr_df.abs() > 0.2).any(axis=1)].index\n",
    "print(selected_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing feature pairs: 100%|██████████| 238/238 [04:18<00:00,  1.09s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    feature_1    feature_2  lag  correlation\n",
      "0  o:mechvent  o:Weight_kg    0     0.071849\n",
      "1  o:mechvent  o:Weight_kg    1     0.071639\n",
      "2  o:mechvent  o:Weight_kg    2     0.070113\n",
      "3  o:mechvent  o:Weight_kg    3     0.067919\n",
      "4  o:mechvent         o:HR    0     0.018215\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm\n",
    "\n",
    "def process_feature_pair(feature_1, feature_2, groups, lags):\n",
    "    # Initialize a dictionary to store shifted arrays for each lag.\n",
    "    pooled = {lag: {\"x1\": [], \"x2\": []} for lag in lags}\n",
    "    \n",
    "    for group in groups:\n",
    "        arr1 = group[feature_1].values\n",
    "        arr2 = group[feature_2].values\n",
    "        n = len(group)\n",
    "        \n",
    "        # For lag 0, require at least 2 data points.\n",
    "        if n >= 2:\n",
    "            pooled[0][\"x1\"].append(arr1)\n",
    "            pooled[0][\"x2\"].append(arr2)\n",
    "            \n",
    "        # For positive lags, slice the arrays if enough data exists.\n",
    "        for lag in lags[1:]:\n",
    "            if n > lag:\n",
    "                pooled[lag][\"x1\"].append(arr1[lag:])\n",
    "                pooled[lag][\"x2\"].append(arr2[:-lag])\n",
    "    \n",
    "    # Compute correlations for each lag.\n",
    "    results = []\n",
    "    for lag in lags:\n",
    "        if not pooled[lag][\"x1\"]:\n",
    "            corr = np.nan\n",
    "        else:\n",
    "            x1 = np.concatenate(pooled[lag][\"x1\"])\n",
    "            x2 = np.concatenate(pooled[lag][\"x2\"])\n",
    "            # Check for sufficient data and non-constant arrays.\n",
    "            if len(x1) < 2 or np.std(x1) == 0 or np.std(x2) == 0:\n",
    "                corr = np.nan\n",
    "            else:\n",
    "                corr = np.corrcoef(x1, x2)[0, 1]\n",
    "        results.append({\n",
    "            'feature_1': feature_1,\n",
    "            'feature_2': feature_2,\n",
    "            'lag': lag,\n",
    "            'correlation': corr\n",
    "        })\n",
    "    return results\n",
    "\n",
    "def pooled_lagged_correlation_main():\n",
    "    # Load dataset\n",
    "    df = pd.read_csv(\"./sepsis_final_data_RAW_withTimes_continuous.csv\")\n",
    "    \n",
    "    # Ensure data are sorted by patient ('traj') and time ('step')\n",
    "    df = df.sort_values(['traj', 'step'])\n",
    "    \n",
    "    # Identify observation columns (those starting with \"o\")\n",
    "    obs_cols = [col for col in df.columns if col.startswith('o')]\n",
    "    # Remove specific columns not needed\n",
    "    cols_to_remove = [\"o:max_dose_vaso\", \"o:input_4hourly\", \"o:gender\", \"o:re_admission\", \"o:age\", \"o:SOFA\"]\n",
    "    for col in cols_to_remove:\n",
    "        if col in obs_cols:\n",
    "            obs_cols.remove(col)\n",
    "    \n",
    "    # Define two sets of features.\n",
    "    sel_cols_1 = ['o:mechvent', 'o:GCS', 'o:FiO2_1', 'o:paO2', 'o:PaO2_FiO2',\n",
    "                  'o:Total_bili', 'o:output_4hourly']\n",
    "    sel_cols_2 = [col for col in obs_cols if col not in sel_cols_1]\n",
    "    \n",
    "    # Define the lags to search (0 to 3).\n",
    "    lags = [0, 1, 2, 3]\n",
    "    \n",
    "    # Cache groups: list of DataFrames for each unique 'traj', sorted by 'step'.\n",
    "    groups = [group.sort_values('step') for _, group in df.groupby('traj')]\n",
    "    \n",
    "    # Prepare the list of feature pairs (compute each pair only once).\n",
    "    feature_pairs = [(f1, f2) for f1 in sel_cols_1 for f2 in sel_cols_2]\n",
    "    \n",
    "    # Process feature pairs in parallel.\n",
    "    all_results = Parallel(n_jobs=-1)(\n",
    "        delayed(process_feature_pair)(f1, f2, groups, lags)\n",
    "        for f1, f2 in tqdm(feature_pairs, desc=\"Processing feature pairs\")\n",
    "    )\n",
    "    \n",
    "    # Flatten the list of lists into a single list of dictionaries.\n",
    "    flattened_results = [item for sublist in all_results for item in sublist]\n",
    "    \n",
    "    # Convert the results to a DataFrame and return.\n",
    "    results_df = pd.DataFrame(flattened_results)\n",
    "    return results_df\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    results_df = pooled_lagged_correlation_main()\n",
    "    # For demonstration, print the first few rows of the DataFrame.\n",
    "    print(results_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature_1</th>\n",
       "      <th>feature_2</th>\n",
       "      <th>lag</th>\n",
       "      <th>correlation</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>104</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:SpO2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.217191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:SpO2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.218067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:SpO2</td>\n",
       "      <td>2</td>\n",
       "      <td>0.203638</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:input_total</td>\n",
       "      <td>0</td>\n",
       "      <td>0.247055</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>129</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:input_total</td>\n",
       "      <td>1</td>\n",
       "      <td>0.233322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>130</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:input_total</td>\n",
       "      <td>2</td>\n",
       "      <td>0.229206</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>o:mechvent</td>\n",
       "      <td>o:input_total</td>\n",
       "      <td>3</td>\n",
       "      <td>0.229378</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>948</th>\n",
       "      <td>o:output_4hourly</td>\n",
       "      <td>o:output_total</td>\n",
       "      <td>0</td>\n",
       "      <td>0.279923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>949</th>\n",
       "      <td>o:output_4hourly</td>\n",
       "      <td>o:output_total</td>\n",
       "      <td>1</td>\n",
       "      <td>0.239853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>950</th>\n",
       "      <td>o:output_4hourly</td>\n",
       "      <td>o:output_total</td>\n",
       "      <td>2</td>\n",
       "      <td>0.224943</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>951</th>\n",
       "      <td>o:output_4hourly</td>\n",
       "      <td>o:output_total</td>\n",
       "      <td>3</td>\n",
       "      <td>0.222051</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            feature_1       feature_2  lag  correlation\n",
       "104        o:mechvent          o:SpO2    0     0.217191\n",
       "105        o:mechvent          o:SpO2    1     0.218067\n",
       "106        o:mechvent          o:SpO2    2     0.203638\n",
       "128        o:mechvent   o:input_total    0     0.247055\n",
       "129        o:mechvent   o:input_total    1     0.233322\n",
       "130        o:mechvent   o:input_total    2     0.229206\n",
       "131        o:mechvent   o:input_total    3     0.229378\n",
       "948  o:output_4hourly  o:output_total    0     0.279923\n",
       "949  o:output_4hourly  o:output_total    1     0.239853\n",
       "950  o:output_4hourly  o:output_total    2     0.224943\n",
       "951  o:output_4hourly  o:output_total    3     0.222051"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df_filter = results_df[results_df['correlation'].abs() >= 0.2]\n",
    "results_df_filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"./sepsis_final_data_RAW_withTimes_continuous.csv\")\n",
    "ls_cols_sel = ['traj', 'step', \"o:gender\", \"o:re_admission\", \"o:age\", 'o:Weight_kg', 'o:mechvent', 'o:GCS', 'o:FiO2_1', 'o:paO2', 'o:PaO2_FiO2',  'o:Total_bili', 'o:output_4hourly', 'o:output_total', 'o:input_total', 'o:SpO2', 'o:max_dose_vaso', 'o:input_4hourly', 'o:SOFA', 'r:reward']\n",
    "df = df[ls_cols_sel]\n",
    "df.rename(columns={'o:max_dose_vaso': 'a:max_dose_vaso'}, inplace=True)\n",
    "df.rename(columns={'o:input_4hourly': 'a:input_4hourly'}, inplace=True)\n",
    "df.rename(columns={'o:SOFA': 'r:SOFA'}, inplace=True)\n",
    "df['r:SOFA'] = -df['r:SOFA'].astype(np.float32)\n",
    "df.rename(columns={'o:Weight_kg': 'co:Weight_kg'}, inplace=True)\n",
    "df.to_csv(\"sepsis_final_RAW_continuous_13_weights.csv\", index=False)\n",
    "df = df.drop(columns=['co:Weight_kg'])\n",
    "df.to_csv(\"sepsis_final_RAW_continuous_13.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
