{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "400d00c8",
   "metadata": {},
   "source": [
    "# Data Transition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99976526",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from lifelines import AalenAdditiveFitter, CoxPHFitter, KaplanMeierFitter\n",
    "\n",
    "# ================== Load Raw Data ==================\n",
    "df = pd.read_csv(\"offline_dataset.csv\")\n",
    "\n",
    "# -------- Min-max rescale rewards within each group so every reward is > 0 --------\n",
    "# Each group is mapped into [eps, 1 + eps). A group whose rewards are all equal\n",
    "# (including a single-row group) collapses to exactly eps.\n",
    "# NOTE(review): if the data holds one row per (id, stage), every reward becomes eps --\n",
    "# confirm the grouping matches the data layout.\n",
    "eps = 1e-6\n",
    "\n",
    "\n",
    "def _rescale_positive(values):\n",
    "    \"\"\"Min-max scale one group's rewards, then shift by eps to keep them positive.\"\"\"\n",
    "    lowest = values.min()\n",
    "    spread = values.max() - lowest\n",
    "    return (values - lowest) / (spread + eps) + eps\n",
    "\n",
    "\n",
    "# Group by individual and stage when an \"id\" column exists, otherwise by stage only.\n",
    "group_keys = [\"id\", \"stage\"] if \"id\" in df.columns else \"stage\"\n",
    "df[\"reward\"] = df.groupby(group_keys)[\"reward\"].transform(_rescale_positive)\n",
    "\n",
    "# ================== Extract Columns ==================\n",
    "# State features are every column whose name starts with \"symptom_\".\n",
    "state_cols = list(filter(lambda name: name.startswith(\"symptom_\"), df.columns))\n",
    "\n",
    "# One independent copy of the data per estimator to hold its derived rewards\n",
    "df_aalen, df_cox, df_km = (df.copy() for _ in range(3))\n",
    "\n",
    "# Evaluation horizons: 0.1, 0.2, ..., 1.0\n",
    "time_points = np.round(np.arange(0.1, 1.01, 0.1), 1)\n",
    "\n",
    "# Create the output columns up front, filled with NaN\n",
    "for horizon in time_points:\n",
    "    horizon_col = f\"new_reward_t{int(horizon*10):02d}\"\n",
    "    for frame in (df_aalen, df_cox):\n",
    "        frame[horizon_col] = np.nan\n",
    "df_km[\"new_reward\"] = np.nan\n",
    "\n",
    "# ================== Fit & Predict by Stage ==================\n",
    "# One set of survival models is fitted per stage. The rescaled reward is used as the\n",
    "# duration and `delta` as the event indicator; covariates are the symptom columns,\n",
    "# the action, and their element-wise products. Each estimator is wrapped in its own\n",
    "# try/except so a failed fit only leaves that estimator's columns as NaN for the stage.\n",
    "km_eps = 1e-3  # Guards the KM division; deliberately not named `eps` (the normalisation constant)\n",
    "\n",
    "for stage, group in df.groupby(\"stage\"):\n",
    "    idx = group.index\n",
    "    Y = group[\"reward\"].values.reshape(-1)\n",
    "    Delta = group[\"delta\"].values.reshape(-1)\n",
    "    A = group[\"action\"].values.reshape(-1, 1)\n",
    "    X = group[state_cols].values\n",
    "\n",
    "    # Combine features: X, A, X*A\n",
    "    XA = np.hstack([X, A, X * A])\n",
    "    colnames = [f\"X{i}\" for i in range(X.shape[1])] + [\"A\"] + [f\"X{i}*A\" for i in range(X.shape[1])]\n",
    "\n",
    "    # Design frame shared by all three estimators; fitters receive a copy so that\n",
    "    # nothing a fit does can change what the next estimator sees.\n",
    "    design = pd.DataFrame(XA, columns=colnames)\n",
    "    design[\"Y\"] = Y\n",
    "    design[\"Delta\"] = Delta\n",
    "\n",
    "    # -------- Aalen --------\n",
    "    try:\n",
    "        aaf = AalenAdditiveFitter()\n",
    "        aaf.fit(design.copy(), duration_col=\"Y\", event_col=\"Delta\")\n",
    "\n",
    "        # lifelines documents the `times` argument of\n",
    "        # AalenAdditiveFitter.predict_survival_function as not implemented, so asking\n",
    "        # for a single time still yields the whole curve. Request the full curves once\n",
    "        # and evaluate them at each horizon here instead.\n",
    "        surv_curves = aaf.predict_survival_function(design).sort_index()\n",
    "        timeline = surv_curves.index.to_numpy(dtype=float)\n",
    "        # Expected layout: one row per fitted time, one column per subject.\n",
    "        surv_matrix = surv_curves.to_numpy(dtype=float)\n",
    "        if surv_matrix.shape[1] != len(idx):\n",
    "            raise ValueError(f\"expected {len(idx)} survival curves, got {surv_matrix.shape[1]}\")\n",
    "\n",
    "        for t in time_points:\n",
    "            # Step-function lookup: use the last fitted time that is <= t.\n",
    "            pos = np.searchsorted(timeline, t, side=\"right\") - 1\n",
    "            if pos >= 0:\n",
    "                surv_t = surv_matrix[pos]\n",
    "            else:\n",
    "                # t precedes every fitted time, so nothing has happened yet: S(t) = 1\n",
    "                surv_t = np.ones(surv_matrix.shape[1])\n",
    "            reward_stage = np.log(surv_t + 1e-5)  # small offset avoids log(0)\n",
    "            df_aalen.loc[idx, f\"new_reward_t{int(t*10):02d}\"] = reward_stage\n",
    "    except Exception as e:\n",
    "        print(f\"[Aalen] Stage {stage} error: {e}\")\n",
    "\n",
    "    # -------- Cox --------\n",
    "    try:\n",
    "        cph = CoxPHFitter()\n",
    "        cph.fit(design.copy(), duration_col=\"Y\", event_col=\"Delta\")\n",
    "\n",
    "        for t in time_points:\n",
    "            st_est = cph.predict_survival_function(design, times=[t])\n",
    "            reward_stage = np.log(st_est.values.reshape(-1) + 1e-5)  # small offset avoids log(0)\n",
    "            df_cox.loc[idx, f\"new_reward_t{int(t*10):02d}\"] = reward_stage\n",
    "    except Exception as e:\n",
    "        print(f\"[Cox] Stage {stage} error: {e}\")\n",
    "\n",
    "    # -------- KM --------\n",
    "    try:\n",
    "        kmf = KaplanMeierFitter()\n",
    "        kmf.fit(durations=design[\"Y\"], event_observed=design[\"Delta\"])\n",
    "        # Survival estimate at each row's own duration. np.asarray + reshape also\n",
    "        # covers a single-row stage, where the prediction may collapse to a scalar.\n",
    "        # NOTE(review): this is the KM curve of the event itself; classical IPCW\n",
    "        # weights use the censoring distribution instead -- confirm this is intended.\n",
    "        S = np.asarray(kmf.predict(design[\"Y\"]), dtype=float).reshape(-1)\n",
    "\n",
    "        # Rows with Delta == 0 keep a reward of 0; the rest get Y * Delta / (S + km_eps).\n",
    "        observed = Delta != 0\n",
    "        reward_stage = np.zeros_like(Y, dtype=float)\n",
    "        reward_stage[observed] = (Y[observed] * Delta[observed]) / (S[observed] + km_eps)\n",
    "        df_km.loc[idx, \"new_reward\"] = reward_stage\n",
    "    except Exception as e:\n",
    "        print(f\"[KM] Stage {stage} error: {e}\")\n",
    "\n",
    "# ================== Save Results ==================\n",
    "# (frame, output path, description shown in the summary)\n",
    "outputs = [\n",
    "    (df_aalen, \"offline_dataset_aalen.csv\", \"with columns for multiple time points\"),\n",
    "    (df_cox, \"offline_dataset_cox.csv\", \"with columns for multiple time points\"),\n",
    "    (df_km, \"offline_dataset_km.csv\", \"single column\"),\n",
    "]\n",
    "\n",
    "# Write every file before printing the summary\n",
    "for frame, path, _ in outputs:\n",
    "    frame.to_csv(path, index=False)\n",
    "\n",
    "print(\"Generated new reward tables by stage:\")\n",
    "for _, path, note in outputs:\n",
    "    print(f\" - {path} ({note})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa9912d",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17486748",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import gym\n",
    "from epicare import EpiCare\n",
    "from dqn_utils import load_dataset, train_dqn, evaluate_policy\n",
    "import pandas as pd\n",
    "\n",
    "# Reward variant -> CSV written by the data-preparation cell\n",
    "datasets = dict(\n",
    "    Aalen=\"offline_dataset_aalen.csv\",\n",
    "    Cox=\"offline_dataset_cox.csv\",\n",
    "    KM=\"offline_dataset_km.csv\",\n",
    ")\n",
    "\n",
    "# Seeds 1..10, one training run each\n",
    "seeds = list(range(1, 11))\n",
    "\n",
    "# ================== Train & Evaluate per Reward Definition ==================\n",
    "def _summarize_seeds(dataset_path, label, n_epochs=50, n_episodes=200, plot_loss=False):\n",
    "    \"\"\"Train one DQN per seed on ``dataset_path`` and aggregate the evaluation metrics.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_path : str\n",
    "        CSV handed to ``load_dataset``. The calling loop makes sure the reward to\n",
    "        train on is stored in a column named ``new_reward``.\n",
    "    label : str\n",
    "        Tag used only in the title of the optional loss plot.\n",
    "    n_epochs : int\n",
    "        Forwarded to ``train_dqn``.\n",
    "    n_episodes : int\n",
    "        Forwarded to ``evaluate_policy``.\n",
    "    plot_loss : bool\n",
    "        When True, plot the training loss history of every seed.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``reward_mean``, ``reward_std``, ``censor_mean`` and ``censor_std``: mean and\n",
    "        population standard deviation (``np.std`` default) across all seeds.\n",
    "    \"\"\"\n",
    "    states, next_states, actions, rewards, dones, state_dim, n_actions = load_dataset(dataset_path)\n",
    "\n",
    "    results = []\n",
    "    for seed in seeds:\n",
    "        print(f\"\\n🔁 Running with seed={seed}\")\n",
    "        policy_net, loss_history = train_dqn(seed, states, next_states, actions, rewards, dones,\n",
    "                                             state_dim, n_actions, n_epochs=n_epochs)\n",
    "\n",
    "        if plot_loss:\n",
    "            plt.plot(loss_history)\n",
    "            plt.xlabel(\"Epoch\")\n",
    "            plt.ylabel(\"Loss\")\n",
    "            plt.title(f\"{label} DQN Loss (seed={seed})\")\n",
    "            plt.show()\n",
    "\n",
    "        # A fresh environment per seed; close it even when evaluation raises so\n",
    "        # environments do not accumulate over the many runs of this cell.\n",
    "        env = gym.make(\"EpiCare-v0\")\n",
    "        try:\n",
    "            avg_reward, censor_rate = evaluate_policy(env, policy_net, n_episodes=n_episodes)\n",
    "        finally:\n",
    "            env.close()\n",
    "        print(f\"Seed {seed}: AvgReward={avg_reward:.2f}, CensorRate={censor_rate:.2%}\")\n",
    "        results.append((avg_reward, censor_rate))\n",
    "\n",
    "    rewards_list = [r for r, _ in results]\n",
    "    censors_list = [c for _, c in results]\n",
    "    return {\n",
    "        \"reward_mean\": np.mean(rewards_list),\n",
    "        \"reward_std\": np.std(rewards_list),\n",
    "        \"censor_mean\": np.mean(censors_list),\n",
    "        \"censor_std\": np.std(censors_list),\n",
    "    }\n",
    "\n",
    "\n",
    "# Iterate through different reward files\n",
    "for name, csv_file in datasets.items():\n",
    "    print(f\"\\n================= {name} Reward =================\")\n",
    "\n",
    "    # --- KM has only one new_reward column ---\n",
    "    if name == \"KM\":\n",
    "        stats = _summarize_seeds(csv_file, name)\n",
    "        print(f\"\\n📊 {name} Reward Final Statistical Results: {stats}\")\n",
    "        # np.save pickles the dict; read it back with\n",
    "        # np.load(path, allow_pickle=True).item()\n",
    "        np.save(f\"results_{name}.npy\", stats)\n",
    "        print(f\"Successfully saved {name} results to results_{name}.npy\")\n",
    "\n",
    "    # --- Cox / Aalen have multiple time points ---\n",
    "    else:\n",
    "        reward_table = pd.read_csv(csv_file)\n",
    "        time_cols = [c for c in reward_table.columns if c.startswith(\"new_reward_t\")]\n",
    "        if not time_cols:\n",
    "            print(f\"[{name}] warning: no new_reward_t* columns found in {csv_file}\")\n",
    "        results_all = {}\n",
    "\n",
    "        for col in time_cols:\n",
    "            print(f\"\\n===== {name} using {col} =====\")\n",
    "            if reward_table[col].isna().any():\n",
    "                # Training still runs, but the reward signal is incomplete.\n",
    "                print(f\"[{name}] warning: {col} contains NaN rewards\")\n",
    "\n",
    "            # Temporarily rename to \"new_reward\" for load_dataset usage.\n",
    "            # The temporary CSV is left on disk for inspection.\n",
    "            tmp_file = f\"tmp_{name}_{col}.csv\"\n",
    "            reward_table.rename(columns={col: \"new_reward\"}).to_csv(tmp_file, index=False)\n",
    "\n",
    "            results_all[col] = _summarize_seeds(tmp_file, f\"{name} ({col})\")\n",
    "            print(f\"\\n📊 {name} ({col}) Reward Statistical Results: {results_all[col]}\")\n",
    "\n",
    "        # Save the complete time-point results dictionary (pickled, see the note above np.save in the KM branch of this loop)\n",
    "        np.save(f\"results_{name}_all.npy\", results_all)\n",
    "        print(f\"Successfully saved all {name} time-point results to results_{name}_all.npy\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
