{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "01177703",
   "metadata": {},
   "source": [
    "# Data transition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15bd3f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from lifelines import KaplanMeierFitter\n",
    "\n",
    "# Read data\n",
    "df = pd.read_csv('processed_data.csv')\n",
    "\n",
    "# Define the function to compute the Cox survival model\n",
    "def Survival_C(time, Y, X, A, Delta):\n",
    "    A = np.reshape(A, (A.shape[0], 1))\n",
    "    A = 2 * A - 1  # Make A -1 or 1\n",
    "    # Compute (2A - 1) * X\n",
    "    X_adjusted = np.multiply(X, A.reshape(-1, 1))  # (2A - 1) * X\n",
    "\n",
    "    # Create a DataFrame containing adjusted X and A\n",
    "    data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "    pre_data = pd.DataFrame(X_adjusted, columns=[f'X{i + 1}*A' for i in range(X.shape[1])])\n",
    "\n",
    "    data['Y'] = Y.flatten()\n",
    "    data['1-Delta'] = 1 - Delta.flatten()\n",
    "\n",
    "    pre_data['Y'] = Y.flatten()\n",
    "    pre_data['1-Delta'] = 1 - Delta.flatten()\n",
    "    \n",
    "    # Fit Cox model with regularization\n",
    "    kmf = KaplanMeierFitter()\n",
    "    kmf.fit(durations=data['Y'], event_observed=data['1-Delta'])\n",
    "\n",
    "    sc_estimate = kmf.survival_function_at_times(time)\n",
    "\n",
    "    # print(np.log(st_estimate.values))\n",
    "    return sc_estimate.values\n",
    "\n",
    "# Define a function to handle weekly survival analysis\n",
    "def process_for_survival(df):\n",
    "    survival_results = []  # To save survival analysis results\n",
    "    concordance_results = []  # To save concordance index results for each time point\n",
    "\n",
    "    # Loop through data for each week\n",
    "    for week in df['study_week'].unique():\n",
    "        #print(\"Processing week:\", week)\n",
    "        week_data = df[df['study_week'] == week]\n",
    "\n",
    "        X = week_data[['avg_step', 'avg_sleep']].values\n",
    "        \n",
    "        Y = week_data['avg_mood'].values\n",
    "        A = week_data['action'].values\n",
    "        Delta = week_data['censor_mood'].values\n",
    "\n",
    "        survival_probs = Survival_C(Y, Y, X, A, Delta)\n",
    "        # Save the survival time results\n",
    "        for i, participant_id in enumerate(week_data['STUDY_PRTCPT_ID'].values):\n",
    "            row = {\n",
    "                'STUDY_PRTCPT_ID': participant_id,\n",
    "                'study_week': week,\n",
    "                'avg_step': week_data.iloc[i]['avg_step'],\n",
    "                'avg_sleep': week_data.iloc[i]['avg_sleep'],\n",
    "                'avg_mood': Y[i],\n",
    "                'censor_mood': Delta[i],\n",
    "                'action': A[i],\n",
    "                'survival_time': Y[i] * Delta[i] / survival_probs[i]\n",
    "            }\n",
    "            survival_results.append(row)\n",
    "\n",
    "    # Convert to DataFrame and save survival analysis results\n",
    "    survival_df = pd.DataFrame(survival_results)\n",
    "    survival_df.to_csv('survival_results_t.csv', index=False)\n",
    "\n",
    "# Process the data and compute survival times and concordance index\n",
    "process_for_survival(df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e00565e",
   "metadata": {},
   "source": [
    "# Training and predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c97bcaf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pandas as pd\n",
    "from collections import deque\n",
    "\n",
    "# Define DQN Network\n",
    "class DQN(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(DQN, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, 64)\n",
    "        self.fc2 = nn.Linear(64, 64)\n",
    "        self.fc3 = nn.Linear(64, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "# DQN Agent with Soft Target Network Updates\n",
    "class SoftUpdateDQNAgent:\n",
    "    def __init__(self, input_dim, output_dim, gamma=0.99, tau=0.001, lr=0.001):\n",
    "        self.policy_net = DQN(input_dim, output_dim)  # Policy network for action selection\n",
    "        self.target_net = DQN(input_dim, output_dim)  # Target network for stability\n",
    "        self.target_net.load_state_dict(self.policy_net.state_dict())\n",
    "        self.target_net.eval()  # Set target network to evaluation mode\n",
    "\n",
    "        self.gamma = gamma  # Discount factor for future rewards\n",
    "        self.tau = tau      # Target network update rate\n",
    "        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)\n",
    "        self.criterion = nn.MSELoss()  # Mean squared error loss\n",
    "\n",
    "    def update_target_network(self):\n",
    "        \"\"\"Soft update of target network parameters\"\"\"\n",
    "        for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):\n",
    "            target_param.data.copy_(self.tau * policy_param.data + (1 - self.tau) * target_param.data)\n",
    "\n",
    "    def train(self, state, action, reward, next_state):\n",
    "        \"\"\"Perform a single training step\"\"\"\n",
    "        # Convert to tensors\n",
    "        state = torch.FloatTensor(state)\n",
    "        action = torch.LongTensor([action])\n",
    "        reward = torch.FloatTensor([reward])\n",
    "        next_state = torch.FloatTensor(next_state)\n",
    "\n",
    "        # Compute Q-values\n",
    "        q_values = self.policy_net(state)\n",
    "        q_value = q_values.gather(0, action)\n",
    "\n",
    "        # Compute target Q-value using target network\n",
    "        next_q_values = self.target_net(next_state).detach()\n",
    "        max_next_q_value = next_q_values.max()\n",
    "\n",
    "        # Compute target Q-value\n",
    "        target_q_value = reward + self.gamma * max_next_q_value\n",
    "\n",
    "        # Compute loss and optimize\n",
    "        loss = self.criterion(q_value, target_q_value)\n",
    "        \n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "\n",
    "        # Update target network\n",
    "        self.update_target_network()\n",
    "\n",
    "    def predict(self, state):\n",
    "        \"\"\"Predict best action based on current policy\"\"\"\n",
    "        state = torch.FloatTensor(state)\n",
    "        q_values = self.policy_net(state)\n",
    "        action = torch.argmax(q_values).item()\n",
    "        return action\n",
    "\n",
    "def main():\n",
    "    # Specify CSV file path\n",
    "    csv_file = \"survival_results_t.csv\"  # Update with your file path\n",
    "    \n",
    "    # Load data from CSV\n",
    "    try:\n",
    "        data = pd.read_csv(csv_file)\n",
    "        print(f\"Data loaded: {csv_file}, Shape: {data.shape}\")\n",
    "        print(f\"Columns: {data.columns.tolist()}\")\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File {csv_file} not found!\")\n",
    "        return\n",
    "\n",
    "    # Extract relevant data\n",
    "    states = data[['avg_step', 'avg_sleep']].values\n",
    "    actions = data['action'].values\n",
    "\n",
    "    # Check for survival_time column\n",
    "    if 'survival_time' not in data.columns:\n",
    "        print(\"Error: 'survival_time' column not found. Check CSV file format!\")\n",
    "        return\n",
    "\n",
    "    rewards = data['survival_time'].values\n",
    "    \n",
    "    # Create result DataFrame\n",
    "    result_df = data.copy()\n",
    "    \n",
    "    # Initialize agent\n",
    "    input_dim = states.shape[1]\n",
    "    output_dim = len(np.unique(actions))\n",
    "    agent = SoftUpdateDQNAgent(input_dim, output_dim)\n",
    "\n",
    "    # Training loop\n",
    "    for i in range(len(states) - 1):\n",
    "        state = states[i]\n",
    "        action = actions[i]\n",
    "        reward = rewards[i]\n",
    "        next_state = states[i + 1]\n",
    "        agent.train(state, action, reward, next_state)\n",
    "\n",
    "    # Predict actions\n",
    "    predicted_actions = [agent.predict(state) for state in states]\n",
    "\n",
    "    # Calculate prediction accuracy\n",
    "    accuracy = np.mean(np.array(predicted_actions) == actions)\n",
    "    print(f\"Using survival_time as rewards, Prediction accuracy: {accuracy:.4f}\")\n",
    "    \n",
    "    # Add predictions to result DataFrame\n",
    "    result_df['predicted_action'] = predicted_actions\n",
    "\n",
    "    # Save results to new CSV file\n",
    "    output_file = \"rl_data_t.csv\"\n",
    "    result_df.to_csv(output_file, index=False)\n",
    "    print(f\"Prediction results saved to {output_file}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
