{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "686bfb5c",
   "metadata": {},
   "source": [
    "# Tuning hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c85d43",
   "metadata": {},
   "outputs": [],
   "source": [
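    "# NOTE: this cell is disabled -- every line is commented out. Uncomment it to re-run the\n",
    "# penalizer search; the penalizer hard-coded in the next cell (0.0157) presumably came from it.\n",
    "# Caveat: the search maximizes the *in-sample* concordance index, which rewards overfitting;\n",
    "# cross-validated concordance would be a sturdier selection criterion.\n",
    "\n",
    "# import pandas as pd\n",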
    "# import numpy as np\n",
    "# from sklearn.preprocessing import StandardScaler\n",
    "# from lifelines import CoxPHFitter\n",
    "# import matplotlib.pyplot as plt\n",
    "# from tqdm import tqdm\n",
    "\n",
    "# # Load data\n",
    "# df = pd.read_csv('processed_data.csv')\n",
    "\n",
    "# # Define survival analysis function\n",
    "# def survival_analysis(time_points, duration, features, treatment, event, penalizer):\n",
    "#     \"\"\"\n",
    "#     Perform survival analysis using Cox Proportional Hazards model with treatment interaction\n",
    "    \n",
    "#     Parameters:\n",
    "#     time_points (array): Time points at which to estimate survival probabilities\n",
    "#     duration (array): Observed survival times\n",
    "#     features (array): Feature matrix (n_samples x n_features)\n",
    "#     treatment (array): Treatment indicator (0/1)\n",
    "#     event (array): Event indicator (1=event, 0=censored)\n",
    "#     penalizer (float): L2 regularization parameter\n",
    "    \n",
    "#     Returns:\n",
    "#     log_survival (array): Log-transformed survival probabilities\n",
    "#     concordance (float): Concordance index (c-index)\n",
    "#     \"\"\"\n",
    "#     # Transform treatment to [-1, 1] scale\n",
    "#     treatment = np.reshape(treatment, (-1, 1))\n",
    "#     treatment_scaled = 2 * treatment - 1\n",
    "    \n",
    "#     # Compute treatment-weighted features\n",
    "#     weighted_features = np.multiply(features, treatment_scaled)\n",
    "    \n",
    "#     # Prepare DataFrame for model fitting\n",
    "#     feature_names = [f'X{i+1}_treatment' for i in range(features.shape[1])]\n",
    "#     model_data = pd.DataFrame(weighted_features, columns=feature_names)\n",
    "    \n",
    "#     # Add survival time and event columns\n",
    "#     model_data['duration'] = duration\n",
    "#     model_data['event'] = event\n",
    "    \n",
    "#     # Fit CoxPH model with regularization\n",
    "#     cph = CoxPHFitter(penalizer=penalizer)\n",
    "#     cph.fit(model_data, duration_col='duration', event_col='event')\n",
    "    \n",
    "#     # Calculate concordance index\n",
    "#     concordance = cph.concordance_index_\n",
    "    \n",
    "#     # Predict survival function\n",
    "#     survival_probs = cph.predict_survival_function(model_data, times=time_points)\n",
    "    \n",
    "#     return np.log(survival_probs), concordance\n",
    "\n",
    "# # Hyperparameter tuning: Find optimal penalizer\n",
    "# def find_optimal_penalizer(data, time_points, penalizer_range):\n",
    "#     \"\"\"\n",
    "#     Find the optimal penalizer value by maximizing concordance index\n",
    "    \n",
    "#     Returns:\n",
    "#     best_penalizer (float): Optimal penalizer value\n",
    "#     \"\"\"\n",
    "#     best_penalizer = None\n",
    "#     best_concordance_sum = -np.inf\n",
    "#     results = []\n",
    "    \n",
    "#     # Progress bar for hyperparameter search\n",
    "#     with tqdm(total=len(penalizer_range), desc=\"Finding optimal penalizer\") as pbar:\n",
    "#         for penalizer in penalizer_range:\n",
    "#             concordance_sum = 0\n",
    "#             weekly_concordance = []\n",
    "            \n",
    "#             # Iterate over each study week\n",
    "#             for week in data['study_week'].unique():\n",
    "#                 week_data = data[data['study_week'] == week].copy()\n",
    "                \n",
    "#                 # Extract features and targets\n",
    "#                 features = week_data[['avg_step', 'avg_sleep']].values\n",
    "#                 duration = week_data['avg_mood'].values\n",
    "#                 treatment = week_data['action'].values\n",
    "#                 event = week_data['censor_mood'].values\n",
    "                \n",
    "#                 # Calculate survival and concordance\n",
    "#                 _, concordance = survival_analysis(\n",
    "#                     time_points, duration, features, treatment, event, penalizer\n",
    "#                 )\n",
    "#                 concordance_sum += concordance\n",
    "#                 weekly_concordance.append(concordance)\n",
    "            \n",
    "#             # Record results\n",
    "#             results.append({\n",
    "#                 'penalizer': penalizer,\n",
    "#                 'concordance_sum': concordance_sum,\n",
    "#                 'avg_concordance': concordance_sum / len(data['study_week'].unique()),\n",
    "#                 'weekly_concordance': weekly_concordance\n",
    "#             })\n",
    "            \n",
    "#             # Update best penalizer\n",
    "#             if concordance_sum > best_concordance_sum:\n",
    "#                 best_concordance_sum = concordance_sum\n",
    "#                 best_penalizer = penalizer\n",
    "            \n",
    "#             # Update progress bar\n",
    "#             pbar.update(1)\n",
    "#             pbar.set_postfix({\n",
    "#                 'Best penalizer': best_penalizer,\n",
    "#                 'Best concordance': f'{best_concordance_sum:.4f}'\n",
    "#             })\n",
    "    \n",
    "#     # Plot penalizer performance\n",
    "#     plt.figure(figsize=(10, 6))\n",
    "#     plt.plot([r['penalizer'] for r in results], [r['concordance_sum'] for r in results], 'o-')\n",
    "#     plt.xlabel('Penalizer (L2 regularization)')\n",
    "#     plt.ylabel('Sum of Concordance Index')\n",
    "#     plt.title('Effect of Penalizer on Model Performance')\n",
    "#     plt.grid(True)\n",
    "#     plt.show()\n",
    "    \n",
    "#     print(f\"\\nOptimal penalizer: {best_penalizer}, Concordance sum: {best_concordance_sum:.4f}\")\n",
    "#     return best_penalizer\n",
    "\n",
    "# # Process survival analysis with optimal parameters\n",
    "# def process_survival_analysis(data, time_points, penalizer):\n",
    "#     \"\"\"\n",
    "#     Perform survival analysis for each study week using optimal parameters\n",
    "#     \"\"\"\n",
    "#     concordance_results = []\n",
    "#     all_survival_probs = []\n",
    "    \n",
    "#     # Iterate over each study week\n",
    "#     for week in data['study_week'].unique():\n",
    "#         week_data = data[data['study_week'] == week].copy()\n",
    "        \n",
    "#         # Extract features and targets\n",
    "#         features = week_data[['avg_step', 'avg_sleep']].values\n",
    "#         duration = week_data['avg_mood'].values\n",
    "#         treatment = week_data['action'].values\n",
    "#         event = week_data['censor_mood'].values\n",
    "        \n",
    "#         # Calculate survival probabilities and concordance\n",
    "#         survival_probs, concordance = survival_analysis(\n",
    "#             time_points, duration, features, treatment, event, penalizer\n",
    "#         )\n",
    "        \n",
    "#         # Record results\n",
    "#         concordance_results.append({\n",
    "#             'study_week': week,\n",
    "#             'concordance_index': concordance\n",
    "#         })\n",
    "        \n",
    "#         all_survival_probs.append({\n",
    "#             'week': week,\n",
    "#             'survival_probs': survival_probs,\n",
    "#             'participant_ids': week_data['STUDY_PRTCPT_ID'].values\n",
    "#         })\n",
    "    \n",
    "#     # Calculate average concordance\n",
    "#     avg_concordance = np.mean([r['concordance_index'] for r in concordance_results])\n",
    "#     print(f\"\\nAverage concordance index across all weeks: {avg_concordance:.4f}\")\n",
    "    \n",
    "#     return all_survival_probs, concordance_results\n",
    "\n",
    "# # Define time points for survival estimation\n",
    "# time_points = np.arange(5, 71, 5)  # From 5 to 70 in steps of 5\n",
    "\n",
    "# # Define penalizer search range\n",
    "# penalizer_range = np.linspace(0, 0.025, 100)  # 100 points between 0 and 0.025\n",
    "\n",
    "# # Find optimal penalizer\n",
    "# best_penalizer = find_optimal_penalizer(df, time_points, penalizer_range)\n",
    "\n",
    "# # Perform final survival analysis with optimal penalizer\n",
    "# survival_probs, concordance_results = process_survival_analysis(\n",
    "#     df, time_points, best_penalizer\n",
    "# )\n",
    "\n",
    "# # Optional: Print sample survival probabilities\n",
    "# # ... (add your code here)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a657b723",
   "metadata": {},
   "source": [
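    "# Data Transition: weekly Cox survival estimates\n",
    "\n",
    "Fits one penalized Cox model per `study_week` and writes the estimates to `survival_results_cox.csv`. Note that the `survival_probability_time_*` columns hold **log** survival probabilities, not raw probabilities."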
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6bf7b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from lifelines import CoxPHFitter\n",
    "\n",
    "# Load the processed dataset that the weekly Cox fits below are run on\n",
    "df = pd.read_csv(filepath_or_buffer='processed_data.csv')\n",
    "\n",
    "# Define the function to compute the Cox survival model\n",
    "def Survival_T(time, Y, X, A, Delta):  # Cox model\n",
    "    A = np.reshape(A, (A.shape[0], 1))\n",
    "    A = 2 * A - 1  # Make A -1 or 1\n",
    "\n",
    "    # Compute (2A - 1) * X\n",
    "    X_adjusted = np.multiply(X, A.reshape(-1, 1))  # (2A - 1) * X\n",
    "\n",
    "    # Create a DataFrame containing adjusted X and A\n",
    "    data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "    pre_data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "\n",
    "    # Add survival time and event indicator columns\n",
    "    data['Y'] = Y.flatten()\n",
    "    data['Delta'] = Delta.flatten()\n",
    "\n",
    "    pre_data['Y'] = Y.flatten()\n",
    "    pre_data['Delta'] = Delta.flatten()\n",
    "    \n",
    "    # Fit Cox model with regularization\n",
    "    cox = CoxPHFitter(penalizer=0.0157)\n",
    "    cox.fit(data, duration_col='Y', event_col='Delta')\n",
    "\n",
    "    # Concordance index\n",
    "    concordance_index = cox.concordance_index_\n",
    "    print(cox.concordance_index_)\n",
    "\n",
    "    # Estimate the survival function at time points\n",
    "    st_estimate = cox.predict_survival_function(pre_data, times=time)\n",
    "    # print(np.log(st_estimate.values))\n",
    "    return np.log(st_estimate.values), concordance_index\n",
    "\n",
    "\n",
    "# Define a function to handle weekly survival analysis\n",
    "def process_for_survival(df, times):\n",
    "    survival_results = []  # To save survival analysis results\n",
    "    concordance_results = []  # To save concordance index results for each time point\n",
    "\n",
    "    # Loop through data for each week\n",
    "    for week in df['study_week'].unique():\n",
    "        #print(\"Processing week:\", week)\n",
    "        week_data = df[df['study_week'] == week]\n",
    "\n",
    "        # Extract X, Y, A, Delta (including newly added non-linear features)\n",
    "        \n",
    "        X = week_data[['avg_step', 'avg_sleep']].values\n",
    "        \n",
    "        Y = week_data['avg_mood'].values\n",
    "        A = week_data['action'].values\n",
    "        Delta = week_data['censor_mood'].values\n",
    "\n",
    "        # Compute survival function and concordance index\n",
    "        survival_probs, concordance_index = Survival_T(times, Y, X, A, Delta)\n",
    "\n",
    "        # Record the concordance index\n",
    "        concordance_results.append({\n",
    "            'study_week': week,\n",
    "            'concordance_index': concordance_index\n",
    "        })\n",
    "\n",
    "        # Save the survival probability results\n",
    "        for i, participant_id in enumerate(week_data['STUDY_PRTCPT_ID'].values):\n",
    "            row = {\n",
    "                'STUDY_PRTCPT_ID': participant_id,\n",
    "                'study_week': week,\n",
    "                'avg_step': week_data.iloc[i]['avg_step'],\n",
    "                'avg_sleep': week_data.iloc[i]['avg_sleep'],\n",
    "                'avg_mood': Y[i],\n",
    "                'censor_mood': Delta[i],\n",
    "                'action': A[i],\n",
    "            }\n",
    "            for t, prob in zip(times, survival_probs[:, i]):\n",
    "                row[f'survival_probability_time_{t}'] = prob\n",
    "\n",
    "            survival_results.append(row)\n",
    "\n",
    "    # Convert to DataFrame and save survival analysis results\n",
    "    survival_df = pd.DataFrame(survival_results)\n",
    "    survival_df.to_csv('survival_results_cox.csv', index=False)\n",
    "\n",
    "    # Convert to DataFrame and save concordance index results\n",
    "    concordance_df = pd.DataFrame(concordance_results)\n",
    "    concordance_df.to_csv('concordance_index_results_cox.csv', index=False)\n",
    "\n",
    "\n",
    "# Survival-estimation time points: 5, 10, ..., 70 (np.arange excludes the stop value 71)\n",
    "times = np.arange(5, 71, 5)\n",
    "\n",
    "# Fit the weekly Cox models; writes survival_results_cox.csv and concordance_index_results_cox.csv\n",
    "# (the trailing semicolon keeps the notebook from echoing a return value)\n",
    "process_for_survival(df, times);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ec7b459",
   "metadata": {},
   "source": [
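    "# DQN Training and Prediction (Cox survival rewards)\n",
    "\n",
    "Trains one soft-update DQN agent per survival-time column of `survival_results_cox.csv`, using that column as the reward, and writes the predicted actions to `rl_data_cox.csv`."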
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02e0e38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pandas as pd\n",
    "from collections import deque\n",
    "\n",
    "# Define DQN Network\n",
    "class DQN(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(DQN, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, 64)\n",
    "        self.fc2 = nn.Linear(64, 64)\n",
    "        self.fc3 = nn.Linear(64, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "# DQN Agent with Soft Target Network Updates\n",
    "class SoftUpdateDQNAgent:\n",
    "    def __init__(self, input_dim, output_dim, gamma=0.99, tau=0.001, lr=0.001):\n",
    "        self.policy_net = DQN(input_dim, output_dim)  # Policy network for action selection\n",
    "        self.target_net = DQN(input_dim, output_dim)  # Target network for stability\n",
    "        self.target_net.load_state_dict(self.policy_net.state_dict())\n",
    "        self.target_net.eval()  # Set target network to evaluation mode\n",
    "\n",
    "        self.gamma = gamma  # Discount factor for future rewards\n",
    "        self.tau = tau      # Target network update rate\n",
    "        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)\n",
    "        self.criterion = nn.MSELoss()  # Mean squared error loss\n",
    "\n",
    "    def update_target_network(self):\n",
    "        \"\"\"Soft update of target network parameters\"\"\"\n",
    "        for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):\n",
    "            target_param.data.copy_(self.tau * policy_param.data + (1 - self.tau) * target_param.data)\n",
    "\n",
    "    def train(self, state, action, reward, next_state):\n",
    "        \"\"\"Perform a single training step\"\"\"\n",
    "        # Convert to tensors\n",
    "        state = torch.FloatTensor(state)\n",
    "        action = torch.LongTensor([action])\n",
    "        reward = torch.FloatTensor([reward])\n",
    "        next_state = torch.FloatTensor(next_state)\n",
    "\n",
    "        # Compute Q-values\n",
    "        q_values = self.policy_net(state)\n",
    "        q_value = q_values.gather(0, action)\n",
    "\n",
    "        # Compute target Q-value using target network\n",
    "        next_q_values = self.target_net(next_state).detach()\n",
    "        max_next_q_value = next_q_values.max()\n",
    "\n",
    "        # Compute target Q-value\n",
    "        target_q_value = reward + self.gamma * max_next_q_value\n",
    "\n",
    "        # Compute loss and optimize\n",
    "        loss = self.criterion(q_value, target_q_value)\n",
    "        \n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "\n",
    "        # Update target network\n",
    "        self.update_target_network()\n",
    "\n",
    "    def predict(self, state):\n",
    "        \"\"\"Predict best action based on current policy\"\"\"\n",
    "        state = torch.FloatTensor(state)\n",
    "        q_values = self.policy_net(state)\n",
    "        action = torch.argmax(q_values).item()\n",
    "        return action\n",
    "\n",
    "def main():\n",
    "    # Specify CSV file path\n",
    "    csv_file = \"survival_results_cox.csv\"  # Update with your file path\n",
    "    \n",
    "    # Load data from CSV\n",
    "    try:\n",
    "        data = pd.read_csv(csv_file)\n",
    "        print(f\"Data loaded: {csv_file}, Shape: {data.shape}\")\n",
    "        print(f\"Columns: {data.columns.tolist()}\")\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File {csv_file} not found!\")\n",
    "        return\n",
    "\n",
    "    # Extract relevant data\n",
    "    states = data[['avg_step', 'avg_sleep']].values\n",
    "    actions = data['action'].values\n",
    "\n",
    "    # Identify survival probability columns\n",
    "    survival_columns = [col for col in data.columns if 'survival_probability_time' in col]\n",
    "    \n",
    "    if not survival_columns:\n",
    "        print(\"No survival probability columns found. Check CSV file format!\")\n",
    "        return\n",
    "\n",
    "    # Create result DataFrame\n",
    "    result_df = data.copy()\n",
    "    \n",
    "    # Train and predict for each time point\n",
    "    for col in survival_columns:\n",
    "        rewards = data[col].values\n",
    "        input_dim = states.shape[1]\n",
    "        output_dim = len(np.unique(actions))\n",
    "        agent = SoftUpdateDQNAgent(input_dim, output_dim)\n",
    "\n",
    "        # Training loop\n",
    "        for i in range(len(states) - 1):\n",
    "            state = states[i]\n",
    "            action = actions[i]\n",
    "            reward = rewards[i]\n",
    "            next_state = states[i + 1]\n",
    "            agent.train(state, action, reward, next_state)\n",
    "\n",
    "        # Predict actions\n",
    "        predicted_actions = [agent.predict(state) for state in states]\n",
    "\n",
    "        # Calculate prediction accuracy\n",
    "        accuracy = np.mean(np.array(predicted_actions) == actions)\n",
    "        print(f\"Time: {col}, Prediction accuracy: {accuracy:.4f}\")\n",
    "        \n",
    "        # Extract time value from column name\n",
    "        time_value = col.split('_')[-1]\n",
    "        \n",
    "        # Add predictions to result DataFrame\n",
    "        result_df[f'predicted_action_{time_value}'] = predicted_actions\n",
    "\n",
    "    # Save results to new CSV file\n",
    "    output_file = \"rl_data_cox.csv\"\n",
    "    result_df.to_csv(output_file, index=False)\n",
    "    print(f\"Prediction results saved to {output_file}\")\n",
    "\n",
    "# A Jupyter kernel runs cells with __name__ == '__main__', so this executes with the cell\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc33ff55",
   "metadata": {},
   "source": [
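    "# Estimate P with prediction\n",
    "\n",
    "Refits a log-normal AFT survival model per `study_week` and compares log-survival estimates under the observed actions, the per-time predicted actions (`predict_action_{t}` columns of `rl_data_aah.csv`) and the predicted actions in `rl_data_t.csv`. Results are written to `data_aah.csv`."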
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ebc2cf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from lifelines import LogNormalAFTFitter\n",
    "\n",
    "# Read inputs: the processed data plus two RL prediction files.\n",
    "# NOTE(review): the DQN cell above writes rl_data_cox.csv, but this cell reads rl_data_aah.csv and\n",
    "# rl_data_t.csv, presumably produced elsewhere -- confirm these are the intended inputs.\n",
    "df = pd.read_csv(filepath_or_buffer='processed_data.csv')\n",
    "new_df = pd.read_csv(filepath_or_buffer='rl_data_aah.csv')\n",
    "t_df = pd.read_csv(filepath_or_buffer='rl_data_t.csv')\n",
    "\n",
    "# Define the function to compute the Cox survival model\n",
    "def Survival_T(time, Y, X, A, Delta, new_A=None):  # Cox model\n",
    "    A = np.reshape(A, (A.shape[0], 1))\n",
    "    A = 2 * A - 1  # Make A -1 or 1\n",
    "\n",
    "    # Compute (2A - 1) * X for fitting\n",
    "    X_adjusted_fit = np.multiply(X, A.reshape(-1, 1))  # (2A - 1) * X\n",
    "\n",
    "    # Create a DataFrame containing adjusted X and A for fitting\n",
    "    data = pd.DataFrame(X_adjusted_fit, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "\n",
    "    # Add survival time and event indicator columns\n",
    "    data['Y'] = Y.flatten()\n",
    "    data['Delta'] = Delta.flatten()\n",
    "\n",
    "    # Fit Cox model with regularization\n",
    "    cox = LogNormalAFTFitter()\n",
    "    cox.fit(data, duration_col='Y', event_col='Delta')\n",
    "\n",
    "    if new_A is not None:\n",
    "        new_A = np.reshape(new_A, (new_A.shape[0], 1))\n",
    "        new_A = 2 * new_A - 1  # Make new_A -1 or 1\n",
    "        X_adjusted_pred = np.multiply(X, new_A.reshape(-1, 1))  # (2new_A - 1) * X\n",
    "        pre_data = pd.DataFrame(X_adjusted_pred, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "    else:\n",
    "        pre_data = pd.DataFrame(X_adjusted_fit, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "\n",
    "    pre_data['Y'] = Y.flatten()\n",
    "    pre_data['Delta'] = Delta.flatten()\n",
    "\n",
    "    # Estimate the survival function at time points\n",
    "    st_estimate = cox.predict_survival_function(pre_data, times=time)\n",
    "\n",
    "    # 将 inf 替换为 -10\n",
    "    result = np.nan_to_num(np.log(st_estimate.values), posinf=-10, neginf=-10)\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "# Define a function to handle weekly survival analysis\n",
    "def process_for_survival(df, new_df, times):\n",
    "    survival_results = []  # To save survival analysis results\n",
    "\n",
    "    # Loop through data for each week\n",
    "    for week in df['study_week'].unique():\n",
    "        week_data = df[df['study_week'] == week]\n",
    "        week_action = new_df[df['study_week'] == week]\n",
    "        ttdf = t_df[t_df['study_week'] == week]\n",
    "        \n",
    "        X = week_data[['avg_step', 'avg_sleep']].values\n",
    "        \n",
    "        Y = week_data['avg_mood'].values\n",
    "        A = week_data['action'].values\n",
    "        Delta = week_data['censor_mood'].values\n",
    "\n",
    "        # Compute survival function for original data\n",
    "        original_st_estimate = Survival_T(times, Y, X, A, Delta)\n",
    "\n",
    "        for t in times:\n",
    "            new_A = week_action[f'predict_action_{t}'].values\n",
    "            new_st_estimate = Survival_T(times, Y, X, A, Delta, new_A)\n",
    "            \n",
    "            predict_action_t = ttdf['predict_action'].values\n",
    "            predict_action_t_st_estimate = Survival_T(times, Y, X, A, Delta, predict_action_t)\n",
    "\n",
    "            for i, participant_id in enumerate(week_data['STUDY_PRTCPT_ID'].values):\n",
    "                row = {\n",
    "                    'STUDY_PRTCPT_ID': participant_id,\n",
    "                    'study_week': week\n",
    "                }\n",
    "                for t, prob in zip(times, original_st_estimate[:, i]):\n",
    "                    row[f'original_st_estimate_{t}'] = prob\n",
    "                for t, prob in zip(times, new_st_estimate[:, i]):\n",
    "                    row[f'new_st_estimate_{t}'] = prob\n",
    "                for t, prob in zip(times, predict_action_t_st_estimate[:, i]):\n",
    "                    row[f'predict_action_t_st_estimate_{t}'] = prob\n",
    "                survival_results.append(row)\n",
    "\n",
    "    # Convert to DataFrame and save survival analysis results\n",
    "    survival_df = pd.DataFrame(survival_results)\n",
    "    survival_df.to_csv('data_aah.csv', index=False)\n",
    "\n",
    "\n",
    "# Survival-estimation time points: 5, 10, ..., 70 (np.arange excludes the stop value 71)\n",
    "times = np.arange(5, 71, 5)\n",
    "\n",
    "# Process the data and compute survival estimates (writes data_aah.csv);\n",
    "# the trailing semicolon keeps the notebook from echoing a return value\n",
    "process_for_survival(df, new_df, times);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
