{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "08251129",
   "metadata": {},
   "source": [
    "# Tuning hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e6b73d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "# import numpy as np\n",
    "# from sklearn.preprocessing import StandardScaler\n",
    "# from lifelines import AalenAdditiveFitter\n",
    "# import matplotlib.pyplot as plt\n",
    "# from tqdm import tqdm\n",
    "\n",
    "# # Load data\n",
    "# df = pd.read_csv('processed_data.csv')\n",
    "\n",
    "# # Define survival analysis function\n",
    "# def Survival_T(time, Y, X, A, Delta, penalizer):\n",
    "#     A = np.reshape(A, (A.shape[0], 1))\n",
    "#     A = 2 * A - 1  # Convert A to -1 or 1\n",
    "\n",
    "#     # Compute (2A - 1) * X\n",
    "#     X_adjusted = np.multiply(X, A.reshape(-1, 1))\n",
    "\n",
    "#     # Create DataFrame with adjusted X and A\n",
    "#     data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "#     pre_data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "\n",
    "#     # Add survival time and event indicator columns\n",
    "#     data['Y'] = Y.flatten()\n",
    "#     data['Delta'] = Delta.flatten()\n",
    "\n",
    "#     pre_data['Y'] = Y.flatten()\n",
    "#     pre_data['Delta'] = Delta.flatten()\n",
    "    \n",
    "#     # Fit Aalen additive model with regularization\n",
    "#     aaf = AalenAdditiveFitter(coef_penalizer=penalizer)\n",
    "#     aaf.fit(data, duration_col='Y', event_col='Delta')\n",
    "\n",
    "#     # Calculate concordance index\n",
    "#     concordance_index = aaf.concordance_index_\n",
    "\n",
    "#     # Estimate survival function\n",
    "#     st_estimate = aaf.predict_survival_function(pre_data, times=time)\n",
    "#     st_estimate_values = st_estimate\n",
    "\n",
    "#     return np.log(st_estimate_values), concordance_index\n",
    "\n",
    "# # Find optimal coef_penalizer value\n",
    "# def find_best_penalizer(df, times, penalizer_range):\n",
    "#     best_penalizer = None\n",
    "#     best_concordance_sum = -np.inf\n",
    "#     results = []\n",
    "    \n",
    "#     # Create progress bar\n",
    "#     with tqdm(total=len(penalizer_range), desc=\"Finding optimal penalizer\") as pbar:\n",
    "#         for penalizer in penalizer_range:\n",
    "#             concordance_sum = 0\n",
    "            \n",
    "#             # Loop through weekly data\n",
    "#             for week in df['study_week'].unique():\n",
    "#                 week_data = df[df['study_week'] == week]\n",
    "\n",
    "#                 # Extract features and labels\n",
    "#                 X = week_data[['avg_step', 'avg_sleep']].values\n",
    "#                 Y = week_data['avg_mood'].values\n",
    "#                 A = week_data['action'].values\n",
    "#                 Delta = week_data['censor_mood'].values\n",
    "\n",
    "#                 # Calculate survival function and concordance index\n",
    "#                 _, concordance_index = Survival_T(times, Y, X, A, Delta, penalizer)\n",
    "#                 concordance_sum += concordance_index\n",
    "            \n",
    "#             # Record results for current penalizer\n",
    "#             results.append({\n",
    "#                 'penalizer': penalizer,\n",
    "#                 'concordance_sum': concordance_sum,\n",
    "#                 'avg_concordance': concordance_sum / len(df['study_week'].unique())\n",
    "#             })\n",
    "            \n",
    "#             # Update optimal penalizer\n",
    "#             if concordance_sum > best_concordance_sum:\n",
    "#                 best_concordance_sum = concordance_sum\n",
    "#                 best_penalizer = penalizer\n",
    "            \n",
    "#             # Update progress bar\n",
    "#             pbar.update(1)\n",
    "#             pbar.set_postfix({'Best penalizer': best_penalizer, 'Best concordance sum': f'{best_concordance_sum:.4f}'})\n",
    "    \n",
    "#     # Plot results\n",
    "#     plt.figure(figsize=(10, 6))\n",
    "#     plt.plot([r['penalizer'] for r in results], [r['concordance_sum'] for r in results], marker='o')\n",
    "#     plt.xlabel('Penalizer (coef_penalizer)')\n",
    "#     plt.ylabel('Sum of Concordance Index')\n",
    "#     plt.title('Effect of Penalizer on Model Performance')\n",
    "#     plt.grid(True)\n",
    "#     plt.show()\n",
    "    \n",
    "#     print(f\"\\nOptimal penalizer: {best_penalizer}, Concordance sum: {best_concordance_sum:.4f}\")\n",
    "#     return best_penalizer\n",
    "\n",
    "# # Perform survival analysis with optimal parameters\n",
    "# def process_for_survival(df, times, penalizer):\n",
    "#     concordance_results = []\n",
    "#     all_survival_probs = []\n",
    "\n",
    "#     # Loop through weekly data\n",
    "#     for week in df['study_week'].unique():\n",
    "#         week_data = df[df['study_week'] == week]\n",
    "\n",
    "#         # Extract features and labels\n",
    "#         X = week_data[['avg_step', 'avg_sleep']].values\n",
    "#         Y = week_data['avg_mood'].values\n",
    "#         A = week_data['action'].values\n",
    "#         Delta = week_data['censor_mood'].values\n",
    "\n",
    "#         # Calculate survival function and concordance index\n",
    "#         survival_probs, concordance_index = Survival_T(times, Y, X, A, Delta, penalizer)\n",
    "        \n",
    "#         # Record concordance index\n",
    "#         concordance_results.append({\n",
    "#             'study_week': week,\n",
    "#             'concordance_index': concordance_index\n",
    "#         })\n",
    "        \n",
    "#         # Save survival probability results\n",
    "#         all_survival_probs.append({\n",
    "#             'week': week,\n",
    "#             'survival_probs': survival_probs,\n",
    "#             'participant_ids': week_data['STUDY_PRTCPT_ID'].values\n",
    "#         })\n",
    "\n",
    "#     # Calculate and print average concordance index\n",
    "#     avg_concordance = np.mean([r['concordance_index'] for r in concordance_results])\n",
    "#     print(f\"\\nAverage concordance index across all weeks: {avg_concordance:.4f}\")\n",
    "    \n",
    "#     return all_survival_probs, concordance_results\n",
    "\n",
    "# # Set time points\n",
    "# times = np.arange(5, 71, 5)\n",
    "\n",
    "# # Define penalizer range to test\n",
    "# penalizer_range = np.linspace(0.1, 100000.0, 50)\n",
    "\n",
    "# # Find optimal penalizer\n",
    "# best_penalizer = find_best_penalizer(df, times, penalizer_range)\n",
    "\n",
    "# # Perform final survival analysis with optimal penalizer\n",
    "# survival_probs, concordance_results = process_for_survival(df, times, best_penalizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a8eb052",
   "metadata": {},
   "source": [
    "# Data Transition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4befd25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from lifelines import AalenAdditiveFitter\n",
    "\n",
    "# Read the input records. The code below uses the columns study_week,\n",
    "# avg_step, avg_sleep, avg_mood, action, censor_mood and STUDY_PRTCPT_ID.\n",
    "df = pd.read_csv('processed_data.csv')\n",
    "\n",
    "# Define the function that fits the Aalen additive survival model\n",
    "def Survival_T(time, Y, X, A, Delta, penalizer=10000):\n",
    "    \"\"\"Fit an Aalen additive model on (2A - 1) * X and return log-survival estimates.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    time : array-like or None\n",
    "        Time points at which the survival function is wanted.\n",
    "    Y : np.ndarray\n",
    "        Durations (used as ``duration_col``).\n",
    "    X : np.ndarray, shape (n, d)\n",
    "        Covariates.\n",
    "    A : np.ndarray, shape (n,)\n",
    "        Treatment indicator; it is recoded as 2A - 1 before being multiplied into X.\n",
    "    Delta : np.ndarray\n",
    "        Event indicator (used as ``event_col``).\n",
    "    penalizer : float, default 10000\n",
    "        Value passed to ``AalenAdditiveFitter(coef_penalizer=...)``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    log_survival : np.ndarray\n",
    "        Log of the estimated survival function, one row per entry of ``time``\n",
    "        (or per fitted time when ``time`` is None) and one column per individual.\n",
    "    concordance_index : float\n",
    "        The fitter's ``concordance_index_`` attribute (also printed).\n",
    "    \"\"\"\n",
    "    # Recode the treatment as 2A - 1 and interact it with every covariate\n",
    "    signed_A = 2 * np.reshape(A, (A.shape[0], 1)) - 1\n",
    "    X_adjusted = np.multiply(X, signed_A)\n",
    "\n",
    "    # Design matrix plus duration and event columns\n",
    "    data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "    data['Y'] = Y.flatten()\n",
    "    data['Delta'] = Delta.flatten()\n",
    "\n",
    "    # Fit the Aalen additive model with coefficient regularization\n",
    "    aaf = AalenAdditiveFitter(coef_penalizer=penalizer)\n",
    "    aaf.fit(data, duration_col='Y', event_col='Delta')\n",
    "\n",
    "    # Concordance index\n",
    "    concordance_index = aaf.concordance_index_\n",
    "    print(concordance_index)\n",
    "\n",
    "    # Predict for the same individuals the model was fitted on\n",
    "    st_estimate = aaf.predict_survival_function(data, times=time)\n",
    "    # Work on a plain array: callers index the result positionally with [:, i],\n",
    "    # which a DataFrame does not support.\n",
    "    survival = st_estimate.values\n",
    "\n",
    "    if time is not None:\n",
    "        requested = np.atleast_1d(np.asarray(time, dtype=float))\n",
    "        fitted_times = np.asarray(st_estimate.index, dtype=float)\n",
    "        if fitted_times.shape != requested.shape or not np.allclose(fitted_times, requested):\n",
    "            # NOTE(review): lifelines documents `times` as not implemented for this\n",
    "            # fitter in at least some releases, in which case the rows follow the\n",
    "            # fitted timeline instead. Evaluate the step function at the requested\n",
    "            # times: take the last fitted row at or before each one, and S(t) = 1\n",
    "            # before the first fitted time. Assumes the fitted index is ascending.\n",
    "            row_pos = np.searchsorted(fitted_times, requested, side='right') - 1\n",
    "            aligned = np.ones((requested.shape[0], survival.shape[1]))\n",
    "            known = row_pos >= 0\n",
    "            aligned[known] = survival[row_pos[known]]\n",
    "            survival = aligned\n",
    "\n",
    "    return np.log(survival), concordance_index\n",
    "\n",
    "\n",
    "# Define a function to handle weekly survival analysis\n",
    "def process_for_survival(df, times):\n",
    "    \"\"\"Fit one survival model per study week and write the results to CSV.\n",
    "\n",
    "    For every distinct ``study_week`` the matching rows of ``df`` are passed to\n",
    "    ``Survival_T``. Two files are written to the working directory:\n",
    "    ``survival_results_aah.csv`` (one row per participant-week, with one\n",
    "    ``survival_probability_time_<t>`` column per entry of ``times``) and\n",
    "    ``concordance_index_results_aah.csv`` (one row per week).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Must contain study_week, avg_step, avg_sleep, avg_mood, action,\n",
    "        censor_mood and STUDY_PRTCPT_ID.\n",
    "    times : array-like\n",
    "        Time points handed to ``Survival_T``.\n",
    "    \"\"\"\n",
    "    survival_results = []  # One dict per participant-week\n",
    "    concordance_results = []  # One dict per week\n",
    "\n",
    "    # Loop through data for each week\n",
    "    for week in df['study_week'].unique():\n",
    "        week_data = df[df['study_week'] == week]\n",
    "\n",
    "        X = week_data[['avg_step', 'avg_sleep']].values\n",
    "        Y = week_data['avg_mood'].values\n",
    "        A = week_data['action'].values\n",
    "        Delta = week_data['censor_mood'].values\n",
    "\n",
    "        # Compute survival function and concordance index\n",
    "        survival_probs, concordance_index = Survival_T(times, Y, X, A, Delta)\n",
    "        # Survival_T may hand back a DataFrame (np.log of a lifelines result keeps\n",
    "        # the DataFrame type); the positional [:, i] lookup below needs a 2-D array.\n",
    "        survival_probs = np.asarray(survival_probs)\n",
    "\n",
    "        # Record the concordance index\n",
    "        concordance_results.append({\n",
    "            'study_week': week,\n",
    "            'concordance_index': concordance_index\n",
    "        })\n",
    "\n",
    "        # Save the survival probability results\n",
    "        participant_ids = week_data['STUDY_PRTCPT_ID'].values\n",
    "        for i, participant_id in enumerate(participant_ids):\n",
    "            row = {\n",
    "                'STUDY_PRTCPT_ID': participant_id,\n",
    "                'study_week': week,\n",
    "                'avg_step': week_data.iloc[i]['avg_step'],\n",
    "                'avg_sleep': week_data.iloc[i]['avg_sleep'],\n",
    "                'avg_mood': Y[i],\n",
    "                'censor_mood': Delta[i],\n",
    "                'action': A[i],\n",
    "            }\n",
    "            # Column i holds participant i; rows are assumed to line up with `times`\n",
    "            for t, prob in zip(times, survival_probs[:, i]):\n",
    "                row[f'survival_probability_time_{t}'] = prob\n",
    "\n",
    "            survival_results.append(row)\n",
    "\n",
    "    # Convert to DataFrame and save survival analysis results\n",
    "    survival_df = pd.DataFrame(survival_results)\n",
    "    survival_df.to_csv('survival_results_aah.csv', index=False)\n",
    "\n",
    "    # Convert to DataFrame and save concordance index results\n",
    "    concordance_df = pd.DataFrame(concordance_results)\n",
    "    concordance_df.to_csv('concordance_index_results_aah.csv', index=False)\n",
    "\n",
    "times = np.arange(5, 71, 5)  # Time points 5, 10, ..., 70 (np.arange excludes the stop value 71)\n",
    "\n",
    "# Process the data and compute survival probabilities and concordance index\n",
    "process_for_survival(df, times)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c582f8c3",
   "metadata": {},
   "source": [
    "# DQN Agent Training and Action Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02e0e38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pandas as pd\n",
    "from collections import deque\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Deep Q-Network with improved architecture\n",
    "class DQN(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(DQN, self).__init__()\n",
    "        self.network = nn.Sequential(\n",
    "            nn.Linear(input_dim, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm1d(128),\n",
    "            nn.Linear(128, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm1d(64),\n",
    "            nn.Linear(64, output_dim)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.network(x)\n",
    "\n",
    "# DQN Agent with Experience Replay and Soft Target Updates\n",
    "class DQNAgent:\n",
    "    def __init__(self, state_dim, action_dim, \n",
    "                 gamma=0.99, lr=1e-3, tau=5e-3, \n",
    "                 memory_size=10000, batch_size=64):\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "        self.gamma = gamma  # Discount factor\n",
    "        self.tau = tau      # Target network update rate\n",
    "        \n",
    "        # Policy and target networks\n",
    "        self.policy_net = DQN(state_dim, action_dim)\n",
    "        self.target_net = DQN(state_dim, action_dim)\n",
    "        self.target_net.load_state_dict(self.policy_net.state_dict())\n",
    "        self.target_net.eval()\n",
    "        \n",
    "        # Experience replay buffer\n",
    "        self.memory = deque(maxlen=memory_size)\n",
    "        self.batch_size = batch_size\n",
    "        \n",
    "        # Optimizer and loss function\n",
    "        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)\n",
    "        self.criterion = nn.SmoothL1Loss()  # Huber loss for stability\n",
    "    \n",
    "    def store_experience(self, state, action, reward, next_state, done):\n",
    "        \"\"\"Store experience in replay buffer\"\"\"\n",
    "        self.memory.append((state, action, reward, next_state, done))\n",
    "    \n",
    "    def update_target_network(self):\n",
    "        \"\"\"Soft update of target network parameters\"\"\"\n",
    "        for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):\n",
    "            target_param.data.copy_(self.tau * policy_param.data + (1 - self.tau) * target_param.data)\n",
    "    \n",
    "    def train(self):\n",
    "        \"\"\"Sample a batch and perform a training step\"\"\"\n",
    "        if len(self.memory) < self.batch_size:\n",
    "            return 0.0  # Not enough samples to train\n",
    "        \n",
    "        # Sample batch from replay buffer\n",
    "        batch = np.random.sample(self.memory, self.batch_size)\n",
    "        states, actions, rewards, next_states, dones = zip(*batch)\n",
    "        \n",
    "        # Convert to tensors\n",
    "        states = torch.FloatTensor(states)\n",
    "        actions = torch.LongTensor(actions).unsqueeze(1)\n",
    "        rewards = torch.FloatTensor(rewards).unsqueeze(1)\n",
    "        next_states = torch.FloatTensor(next_states)\n",
    "        dones = torch.FloatTensor(dones).unsqueeze(1)\n",
    "        \n",
    "        # Compute Q-values and target Q-values\n",
    "        current_q = self.policy_net(states).gather(1, actions)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            next_q = self.target_net(next_states).max(1)[0].unsqueeze(1)\n",
    "            target_q = rewards + self.gamma * next_q * (1 - dones)\n",
    "        \n",
    "        # Calculate loss and optimize\n",
    "        loss = self.criterion(current_q, target_q)\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        \n",
    "        # Update target network\n",
    "        self.update_target_network()\n",
    "        \n",
    "        return loss.item()\n",
    "    \n",
    "    def select_action(self, state, epsilon=0.1):\n",
    "        \"\"\"Epsilon-greedy action selection\"\"\"\n",
    "        if np.random.random() < epsilon:\n",
    "            return np.random.randint(0, self.action_dim)\n",
    "        else:\n",
    "            state = torch.FloatTensor(state).unsqueeze(0)\n",
    "            with torch.no_grad():\n",
    "                return self.policy_net(state).argmax(1).item()\n",
    "\n",
    "def preprocess_data(file_path):\n",
    "    \"\"\"Load and preprocess data from CSV file\"\"\"\n",
    "    try:\n",
    "        data = pd.read_csv(file_path)\n",
    "        print(f\"Data loaded successfully: {file_path}, Shape: {data.shape}\")\n",
    "        return data\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File {file_path} not found!\")\n",
    "        exit(1)\n",
    "\n",
    "def train_and_evaluate(data, time_point, state_cols, action_col, reward_col_prefix):\n",
    "    \"\"\"Fit a DQNAgent on consecutive rows of ``data`` and score it against the logged actions.\n",
    "\n",
    "    The reward is read from the column ``f\"{reward_col_prefix}_{time_point}\"``.\n",
    "    Returns ``(predicted_actions, accuracy, losses)``, or ``(None, None, None)``\n",
    "    when that reward column is missing.\n",
    "    \"\"\"\n",
    "    # Standardise every state feature; the epsilon guards against zero variance\n",
    "    raw_states = data[state_cols].values\n",
    "    states = (raw_states - raw_states.mean(axis=0)) / (raw_states.std(axis=0) + 1e-8)\n",
    "\n",
    "    actions = data[action_col].values\n",
    "    reward_col = f\"{reward_col_prefix}_{time_point}\"\n",
    "\n",
    "    # Nothing to learn from without the reward column\n",
    "    if reward_col not in data.columns:\n",
    "        print(f\"Warning: Column {reward_col} not found!\")\n",
    "        return None, None, None\n",
    "\n",
    "    rewards = data[reward_col].values\n",
    "\n",
    "    # One network input per state column, one output per distinct logged action\n",
    "    agent = DQNAgent(states.shape[1], len(np.unique(actions)))\n",
    "\n",
    "    num_episodes = 50\n",
    "    n_transitions = len(states) - 1\n",
    "    last_step = n_transitions - 1\n",
    "    losses = []\n",
    "\n",
    "    for _ in tqdm(range(num_episodes), desc=f\"Training for time {time_point}\"):\n",
    "        running_loss = 0\n",
    "        # Row `step` and row `step + 1` form one transition\n",
    "        for step in range(n_transitions):\n",
    "            agent.store_experience(\n",
    "                states[step],\n",
    "                actions[step],\n",
    "                rewards[step],\n",
    "                states[step + 1],\n",
    "                step == last_step,\n",
    "            )\n",
    "            running_loss += agent.train()\n",
    "\n",
    "        losses.append(running_loss / n_transitions)\n",
    "\n",
    "    # Greedy evaluation (epsilon = 0.0 disables exploration)\n",
    "    predicted_actions = [agent.select_action(row, 0.0) for row in states]\n",
    "    accuracy = np.mean(predicted_actions == actions)\n",
    "\n",
    "    print(f\"Time {time_point}: Prediction accuracy = {accuracy:.4f}\")\n",
    "    return predicted_actions, accuracy, losses\n",
    "\n",
    "def main():\n",
    "    \"\"\"Train one agent per survival time point found in the input CSV and save the results.\n",
    "\n",
    "    Reads ``survival_results_aah.csv``, adds a ``predicted_action_<t>`` column for\n",
    "    every time point that trains successfully, then writes\n",
    "    ``rl_prediction_results.csv`` and ``training_metrics.csv``.\n",
    "    \"\"\"\n",
    "    # Configuration\n",
    "    input_file = \"survival_results_aah.csv\"\n",
    "    output_file = \"rl_prediction_results.csv\"\n",
    "    state_columns = [\"avg_step\", \"avg_sleep\"]\n",
    "    action_column = \"action\"\n",
    "    reward_prefix = \"survival_probability_time\"\n",
    "\n",
    "    data = preprocess_data(input_file)\n",
    "\n",
    "    # The time point is the integer after the last underscore of each reward column\n",
    "    time_points = []\n",
    "    for column in data.columns:\n",
    "        if column.startswith(reward_prefix):\n",
    "            time_points.append(int(column.split('_')[-1]))\n",
    "\n",
    "    if not time_points:\n",
    "        print(\"No survival probability columns found!\")\n",
    "        return\n",
    "\n",
    "    print(f\"Found time points: {time_points}\")\n",
    "\n",
    "    performance_metrics = []\n",
    "\n",
    "    for time_point in time_points:\n",
    "        predictions, accuracy, _ = train_and_evaluate(\n",
    "            data, time_point, state_columns, action_column, reward_prefix\n",
    "        )\n",
    "        if predictions is None:\n",
    "            continue\n",
    "\n",
    "        data[f'predicted_action_{time_point}'] = predictions\n",
    "        performance_metrics.append({'time_point': time_point, 'accuracy': accuracy})\n",
    "\n",
    "    # Persist the augmented data and the per-time-point accuracy table\n",
    "    data.to_csv(output_file, index=False)\n",
    "    metrics_frame = pd.DataFrame(performance_metrics)\n",
    "    metrics_frame.to_csv(\"training_metrics.csv\", index=False)\n",
    "\n",
    "    print(f\"\\nResults saved to {output_file}\")\n",
    "    print(\"Training metrics saved to training_metrics.csv\")\n",
    "\n",
    "# Entry point: call main() when this code runs as the top-level module\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a713d4f4",
   "metadata": {},
   "source": [
    "# Estimate P with prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79965970",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from lifelines import CoxPHFitter\n",
    "from lifelines import LogNormalAFTFitter\n",
    "\n",
    "# Read data\n",
    "# df: records that are split by study_week and used to fit the survival models\n",
    "df = pd.read_csv('processed_data.csv')\n",
    "# new_df: read below for its predict_action_<t> columns (one per time point)\n",
    "new_df = pd.read_csv('rl_data_aah.csv') \n",
    "# t_df: read below for its study_week and predict_action columns\n",
    "t_df = pd.read_csv('rl_data_t.csv')\n",
    "\n",
    "# Cox proportional-hazards variant of Survival_T.\n",
    "# NOTE: a second `def Survival_T` further down in this cell rebinds the name,\n",
    "# so calls made after that point do not reach this version.\n",
    "def Survival_T(time, Y, X, A, Delta, new_A=None):  # Cox model\n",
    "    \"\"\"Fit a penalised CoxPH model on (2A - 1) * X and return clipped log-survival values.\n",
    "\n",
    "    The model is always fitted with the observed treatment ``A``. Predictions use\n",
    "    ``new_A`` in place of ``A`` when it is given, otherwise the fitting design.\n",
    "    Returns ``np.log`` of the predicted survival function evaluated at ``time``,\n",
    "    with infinite values replaced by -10.\n",
    "    \"\"\"\n",
    "    feature_names = [f'X{i + 1}*A' for i in range(X.shape[1])]\n",
    "\n",
    "    def _design(treatment):\n",
    "        # Recode the treatment as 2A - 1, multiply it into X, attach Y and Delta\n",
    "        signed = 2 * np.reshape(treatment, (treatment.shape[0], 1)) - 1\n",
    "        frame = pd.DataFrame(np.multiply(X, signed.reshape(-1, 1)), columns=feature_names)\n",
    "        frame['Y'] = Y.flatten()\n",
    "        frame['Delta'] = Delta.flatten()\n",
    "        return frame\n",
    "\n",
    "    # Fit Cox model with regularization\n",
    "    data = _design(A)\n",
    "    cox = CoxPHFitter(penalizer=0.0157)\n",
    "    cox.fit(data, duration_col='Y', event_col='Delta')\n",
    "\n",
    "    # Prediction design: counterfactual treatment if supplied, observed otherwise\n",
    "    pre_data = _design(A if new_A is None else new_A)\n",
    "\n",
    "    # Estimate the survival function at time points\n",
    "    st_estimate = cox.predict_survival_function(pre_data, times=time)\n",
    "\n",
    "    # Replace +inf / -inf with -10 (nan_to_num also maps NaN to 0.0 by default)\n",
    "    log_survival = np.log(st_estimate.values)\n",
    "    return np.nan_to_num(log_survival, posinf=-10, neginf=-10)\n",
    "\n",
    "def Survival_T(time, Y, X, A, Delta, new_A=None):  # Log-normal AFT model\n",
    "    \"\"\"Fit a log-normal AFT model on (2A - 1) * X and return survival probabilities.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    time : array-like\n",
    "        Time points at which the survival function is evaluated.\n",
    "    Y : np.ndarray\n",
    "        Durations; 0.0001 is added before fitting and predicting.\n",
    "    X : np.ndarray, shape (n, d)\n",
    "        Covariates.\n",
    "    A : np.ndarray, shape (n,)\n",
    "        Observed treatment, recoded as 2A - 1; always used for fitting.\n",
    "    Delta : np.ndarray\n",
    "        Event indicator.\n",
    "    new_A : np.ndarray, optional\n",
    "        Alternative treatment used only to build the prediction design.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Values of ``predict_survival_function`` (not logged): one row per time\n",
    "        point and one column per individual.\n",
    "    \"\"\"\n",
    "    feature_names = [f'X{i + 1}*A' for i in range(X.shape[1])]\n",
    "\n",
    "    def _design(treatment):\n",
    "        # Recode the treatment as 2A - 1, multiply it into X, attach Y and Delta\n",
    "        signed = 2 * np.reshape(treatment, (treatment.shape[0], 1)) - 1\n",
    "        frame = pd.DataFrame(np.multiply(X, signed.reshape(-1, 1)), columns=feature_names)\n",
    "        # Small positive shift on the durations, presumably because the\n",
    "        # parametric fitter rejects non-positive durations (TODO confirm)\n",
    "        frame['Y'] = Y.flatten() + 0.0001\n",
    "        frame['Delta'] = Delta.flatten()\n",
    "        return frame\n",
    "\n",
    "    # Fit the log-normal AFT model on the observed treatment\n",
    "    data = _design(A)\n",
    "    aft = LogNormalAFTFitter()\n",
    "    aft.fit(data, duration_col='Y', event_col='Delta')\n",
    "    # The previous version did `con = con + <concordance index>` here. `con` was\n",
    "    # never initialised, so every call raised UnboundLocalError; the accumulation\n",
    "    # has been dropped. Read `aft.concordance_index_` here if it is needed.\n",
    "\n",
    "    # Prediction design: counterfactual treatment if supplied, observed otherwise\n",
    "    pre_data = _design(A if new_A is None else new_A)\n",
    "\n",
    "    # Use the `time` argument, not the module-level `times` variable\n",
    "    surv_df = aft.predict_survival_function(pre_data, times=time)\n",
    "\n",
    "    return surv_df.values\n",
    "\n",
    "# Define a function to handle weekly survival analysis\n",
    "def process_for_survival(df, new_df, times):\n",
    "    \"\"\"Compare survival estimates under observed and predicted actions, week by week.\n",
    "\n",
    "    For every distinct ``study_week`` in ``df`` this calls ``Survival_T`` with the\n",
    "    observed action, with each ``predict_action_<t>`` column of ``new_df`` and with\n",
    "    the ``predict_action`` column of the module-level ``t_df``. The collected rows\n",
    "    are written to ``data_aah.csv``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Must contain study_week, avg_step, avg_sleep, avg_mood, action,\n",
    "        censor_mood and STUDY_PRTCPT_ID.\n",
    "    new_df : pd.DataFrame\n",
    "        Must contain one ``predict_action_<t>`` column per entry of ``times``.\n",
    "    times : array-like\n",
    "        Time points passed to ``Survival_T``.\n",
    "    \"\"\"\n",
    "    survival_results = []  # To save survival analysis results\n",
    "\n",
    "    # Filter new_df by its own week column when it has one (as is done for t_df);\n",
    "    # otherwise fall back to the mask built from df, which assumes new_df is\n",
    "    # row-aligned with df.\n",
    "    new_df_has_week = 'study_week' in new_df.columns\n",
    "\n",
    "    # Loop through data for each week\n",
    "    for week in df['study_week'].unique():\n",
    "        week_data = df[df['study_week'] == week]\n",
    "        if new_df_has_week:\n",
    "            week_action = new_df[new_df['study_week'] == week]\n",
    "        else:\n",
    "            week_action = new_df[df['study_week'] == week]\n",
    "        ttdf = t_df[t_df['study_week'] == week]\n",
    "\n",
    "        X = week_data[['avg_step', 'avg_sleep']].values\n",
    "        Y = week_data['avg_mood'].values\n",
    "        A = week_data['action'].values\n",
    "        Delta = week_data['censor_mood'].values\n",
    "        participant_ids = week_data['STUDY_PRTCPT_ID'].values\n",
    "\n",
    "        # Compute survival function for original data\n",
    "        original_st_estimate = Survival_T(times, Y, X, A, Delta)\n",
    "\n",
    "        # This estimate does not depend on the time point, so fit it once per week\n",
    "        predict_action_t = ttdf['predict_action'].values\n",
    "        predict_action_t_st_estimate = Survival_T(times, Y, X, A, Delta, predict_action_t)\n",
    "\n",
    "        # NOTE(review): one row is appended per participant for every entry of\n",
    "        # `times`, and the row does not record which predict_action_<t> column\n",
    "        # produced its new_st_estimate_* values - confirm this is intended.\n",
    "        for action_time in times:\n",
    "            new_A = week_action[f'predict_action_{action_time}'].values\n",
    "            new_st_estimate = Survival_T(times, Y, X, A, Delta, new_A)\n",
    "\n",
    "            for i, participant_id in enumerate(participant_ids):\n",
    "                row = {\n",
    "                    'STUDY_PRTCPT_ID': participant_id,\n",
    "                    'study_week': week\n",
    "                }\n",
    "                for t, prob in zip(times, original_st_estimate[:, i]):\n",
    "                    row[f'original_st_estimate_{t}'] = prob\n",
    "                for t, prob in zip(times, new_st_estimate[:, i]):\n",
    "                    row[f'new_st_estimate_{t}'] = prob\n",
    "                for t, prob in zip(times, predict_action_t_st_estimate[:, i]):\n",
    "                    row[f'predict_action_t_st_estimate_{t}'] = prob\n",
    "                survival_results.append(row)\n",
    "\n",
    "    # Convert to DataFrame and save survival analysis results\n",
    "    survival_df = pd.DataFrame(survival_results)\n",
    "    survival_df.to_csv('data_aah.csv', index=False)\n",
    "\n",
    "times = np.arange(5, 71, 5)  # Time points 5, 10, ..., 70 (np.arange excludes the stop value 71)\n",
    "\n",
    "# Process the data and compute survival probabilities\n",
    "process_for_survival(df, new_df, times)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
