{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a42cc97e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import gym\n",
    "import random\n",
    "from tqdm import trange\n",
    "import matplotlib\n",
    "# matplotlib.use(\"Agg\")  # Non-interactive backend\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# =============== Common Utility Functions ===============\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed Python's `random`, NumPy's global RNG and torch for reproducible runs.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed applied to every generator (CUDA ones only when a GPU is present).\n",
    "    \"\"\"\n",
    "    for seeder in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "# =============== Step 1: Data Preprocessing ===============\n",
    "# Flatten the offline CSV (one transition per row, trajectories keyed by `pid`)\n",
    "# into NumPy arrays: states, next_states, actions, rewards, dones.\n",
    "print(\"🔧 Loading offline dataset...\")\n",
    "df = pd.read_csv(\"offline_dataset.csv\")\n",
    "# SECURITY NOTE(review): `eval` runs arbitrary code stored in the CSV. Only use trusted\n",
    "# files; if states are plain list/tuple literals, `ast.literal_eval` is the safer choice.\n",
    "df[\"state\"] = df[\"state\"].apply(eval)\n",
    "state_dim = len(df[\"state\"].iloc[0])\n",
    "symptom_cols = [f\"symptom_{i}\" for i in range(state_dim)]\n",
    "next_symptom_cols = [f\"next_symptom_{i}\" for i in range(state_dim)]\n",
    "\n",
    "# One column per state component, added in a single concat (avoids a fragmented frame).\n",
    "state_mat = np.vstack(df[\"state\"].values)\n",
    "df = pd.concat(\n",
    "    [df.drop(columns=[\"state\"]), pd.DataFrame(state_mat, index=df.index, columns=symptom_cols)],\n",
    "    axis=1,\n",
    ")\n",
    "\n",
    "# Next state = the following row of the same `pid` trajectory (rows stay in file order).\n",
    "# The last row of each trajectory has no successor and reuses its own state; its\n",
    "# bootstrap term is only zeroed if that row has done=1 (see the TD target in train_dqn).\n",
    "next_mat = df.groupby(\"pid\")[symptom_cols].shift(-1).to_numpy(dtype=np.float64, copy=True)\n",
    "# Last occurrence of each (non-missing) pid; rows with a missing pid keep NaN.\n",
    "is_last = (~df.duplicated(subset=\"pid\", keep=\"last\") & df[\"pid\"].notna()).to_numpy()\n",
    "next_mat[is_last] = state_mat[is_last]\n",
    "df = pd.concat([df, pd.DataFrame(next_mat, index=df.index, columns=next_symptom_cols)], axis=1)\n",
    "\n",
    "states = df[symptom_cols].values.astype(np.float32)\n",
    "next_states = df[next_symptom_cols].values.astype(np.float32)\n",
    "actions = df[\"action\"].values.astype(np.int64)\n",
    "rewards = df[\"reward\"].values.astype(np.float32)\n",
    "dones = df[\"done\"].astype(int).values\n",
    "# Missing pids or symptom values would otherwise feed NaNs silently into training.\n",
    "assert np.isfinite(states).all() and np.isfinite(next_states).all(), \"NaN/inf in states or next_states\"\n",
    "\n",
    "# Plain Python int so it is safe to use as a layer size.\n",
    "n_actions = int(df[\"action\"].max()) + 1\n",
    "print(f\"Data ready: {states.shape[0]} transitions, state_dim={state_dim}, n_actions={n_actions}\")\n",
    "\n",
    "# =============== Step 2: DQN Definition ===============\n",
    "class DQN(nn.Module):\n",
    "    \"\"\"Q-network: an MLP mapping a state vector to one Q-value per action.\n",
    "\n",
    "    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear, both hidden layers `hidden_dim` wide.\n",
    "\n",
    "    Args:\n",
    "        state_dim: size of the input state vector.\n",
    "        action_dim: number of discrete actions (output size).\n",
    "        hidden_dim: width of each hidden layer (default 128).\n",
    "    \"\"\"\n",
    "    def __init__(self, state_dim, action_dim, hidden_dim=128):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(state_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, action_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return Q-values with shape (..., action_dim) for states of shape (..., state_dim).\"\"\"\n",
    "        return self.net(x)\n",
    "\n",
    "def train_dqn(seed, n_epochs=10, batch_size=64, gamma=0.99, lr=1e-3, verbose=False):\n",
    "    \"\"\"Fit a DQN on the offline dataset with one-step TD (Q-learning) targets.\n",
    "\n",
    "    Reads the module-level arrays `states`, `actions`, `rewards`, `next_states`,\n",
    "    `dones` and the sizes `state_dim` / `n_actions` built in Step 1. The TD target\n",
    "    bootstraps from the network being trained (there is no separate target network).\n",
    "\n",
    "    Args:\n",
    "        seed: passed to `set_seed`; controls weight init and minibatch order.\n",
    "        n_epochs: number of shuffled passes over the dataset.\n",
    "        batch_size: minibatch size (the last batch of an epoch may be smaller).\n",
    "        gamma: discount factor.\n",
    "        lr: Adam learning rate.\n",
    "        verbose: if True, print the mean loss after every epoch.\n",
    "\n",
    "    Returns:\n",
    "        (policy_net, loss_history), where loss_history[k] is the mean minibatch MSE of epoch k.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataset is empty.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    policy_net = DQN(state_dim, n_actions).to(device)\n",
    "    optimizer = optim.Adam(policy_net.parameters(), lr=lr)\n",
    "    loss_fn = nn.MSELoss()\n",
    "\n",
    "    dataset_size = len(states)\n",
    "    if dataset_size == 0:\n",
    "        raise ValueError(\"offline dataset is empty; nothing to train on\")\n",
    "\n",
    "    # Convert and move the dataset to the device once, not once per minibatch.\n",
    "    s_all = torch.as_tensor(states, dtype=torch.float32, device=device)\n",
    "    a_all = torch.as_tensor(actions, dtype=torch.int64, device=device)\n",
    "    r_all = torch.as_tensor(rewards, dtype=torch.float32, device=device)\n",
    "    ns_all = torch.as_tensor(next_states, dtype=torch.float32, device=device)\n",
    "    d_all = torch.as_tensor(dones, dtype=torch.float32, device=device)\n",
    "\n",
    "    loss_history = []  # mean loss of each epoch\n",
    "    for epoch in range(n_epochs):\n",
    "        # Shuffle with the global NumPy RNG seeded by set_seed above.\n",
    "        idxs = torch.as_tensor(np.random.permutation(dataset_size), device=device)\n",
    "        total_loss = 0.0\n",
    "        n_batches = 0\n",
    "        for start in range(0, dataset_size, batch_size):\n",
    "            batch_idx = idxs[start:start + batch_size]\n",
    "            s, a, r = s_all[batch_idx], a_all[batch_idx], r_all[batch_idx]\n",
    "            ns, d = ns_all[batch_idx], d_all[batch_idx]\n",
    "\n",
    "            # squeeze(1) rather than squeeze(): keeps shape (B,) even when B == 1.\n",
    "            q_values = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)\n",
    "            with torch.no_grad():\n",
    "                max_next_q = policy_net(ns).max(1)[0]\n",
    "                # (1 - d) drops the bootstrap term on transitions flagged done.\n",
    "                target = r + gamma * max_next_q * (1 - d)\n",
    "\n",
    "            loss = loss_fn(q_values, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += loss.item()\n",
    "            n_batches += 1\n",
    "        avg_loss = total_loss / n_batches\n",
    "        loss_history.append(avg_loss)\n",
    "        if verbose:\n",
    "            print(f\"Epoch {epoch+1}/{n_epochs}, Loss={avg_loss:.4f}\")\n",
    "\n",
    "    return policy_net, loss_history\n",
    "\n",
    "\n",
    "# =============== Step 3: Policy Evaluation ===============\n",
    "from epicare import EpiCare\n",
    "# `EpiCare` is not referenced directly; presumably importing it registers the \"EpiCare-v0\" gym environment used in Step 4 (confirm).\n",
    "\n",
    "def evaluate_policy(env, policy_net, n_episodes=200):\n",
    "    \"\"\"Roll out the greedy (argmax-Q) policy online and summarise the episodes.\n",
    "\n",
    "    Accepts both the classic gym API (`reset() -> obs`, `step() -> 4-tuple`) and the\n",
    "    gym>=0.26 API (`reset() -> (obs, info)`, `step() -> 5-tuple`).\n",
    "\n",
    "    Args:\n",
    "        env: environment whose final-step `info` dict contains a \"delta\" key; an\n",
    "            episode counts as censored when it ends with info[\"delta\"] == 0.\n",
    "        policy_net: Q-network mapping a batch of states to per-action Q-values.\n",
    "        n_episodes: number of episodes to roll out (must be positive).\n",
    "\n",
    "    Returns:\n",
    "        (avg_reward, censor_rate): mean undiscounted episode return and the\n",
    "        fraction of censored episodes.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if n_episodes is not positive.\n",
    "    \"\"\"\n",
    "    if n_episodes <= 0:\n",
    "        raise ValueError(f\"n_episodes must be positive, got {n_episodes}\")\n",
    "    device = next(policy_net.parameters()).device\n",
    "    episode_returns, censored = [], 0\n",
    "    for _ in range(n_episodes):\n",
    "        state = env.reset()\n",
    "        # gym>=0.26: reset() returns (obs, info).\n",
    "        if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):\n",
    "            state = state[0]\n",
    "        done = False\n",
    "        ep_reward = 0\n",
    "        while not done:\n",
    "            s = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)\n",
    "            with torch.no_grad():\n",
    "                action = policy_net(s).argmax(1).item()\n",
    "            step_out = env.step(action)\n",
    "            if len(step_out) == 5:  # gym>=0.26: obs, reward, terminated, truncated, info\n",
    "                state, reward, terminated, truncated, info = step_out\n",
    "                done = terminated or truncated\n",
    "            else:\n",
    "                state, reward, done, info = step_out\n",
    "            ep_reward += reward\n",
    "            if done and info[\"delta\"] == 0:\n",
    "                censored += 1\n",
    "        episode_returns.append(ep_reward)\n",
    "    return np.mean(episode_returns), censored / n_episodes\n",
    "\n",
    "# =============== Step 4: Multi-seed Experiments ===============\n",
    "# Train one DQN per seed, evaluate it online in EpiCare, then aggregate across seeds.\n",
    "seeds = list(range(1, 11))\n",
    "results = []  # (avg_reward, censor_rate) per seed\n",
    "\n",
    "for seed in seeds:\n",
    "    print(f\"\\n🔁 Running with seed={seed}\")\n",
    "    policy_net, loss_history = train_dqn(seed=seed, n_epochs=50)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(loss_history)\n",
    "    ax.set_xlabel(\"Epoch\")\n",
    "    ax.set_ylabel(\"Loss\")\n",
    "    ax.set_title(\"DQN Loss Convergence\")\n",
    "    plt.show()\n",
    "    plt.close(fig)  # don't keep one open figure per seed\n",
    "\n",
    "    # NOTE(review): the env is never seeded here, so evaluation rollouts may not be reproducible.\n",
    "    env = gym.make(\"EpiCare-v0\")\n",
    "    try:\n",
    "        avg_reward, censor_rate = evaluate_policy(env, policy_net, n_episodes=200)\n",
    "    finally:\n",
    "        env.close()\n",
    "    print(f\"Seed {seed}: AvgReward={avg_reward:.2f}, CensorRate={censor_rate:.2%}\")\n",
    "    results.append((avg_reward, censor_rate))\n",
    "\n",
    "# Distinct names so Step 1's dataset-level `rewards` array (read by train_dqn) is not overwritten.\n",
    "seed_rewards = [r for r, _ in results]\n",
    "seed_censor_rates = [c for _, c in results]\n",
    "# np.std defaults to ddof=0 (population std over seeds).\n",
    "stats = {\n",
    "    \"reward_mean\": np.mean(seed_rewards),\n",
    "    \"reward_std\": np.std(seed_rewards),\n",
    "    \"censor_mean\": np.mean(seed_censor_rates),\n",
    "    \"censor_std\": np.std(seed_censor_rates),\n",
    "}\n",
    "\n",
    "print(\"\\n Final Statistical Results:\")\n",
    "print(f\"Average Reward = {stats['reward_mean']:.2f} ± {stats['reward_std']:.2f}\")\n",
    "print(f\"Censor Rate = {stats['censor_mean']:.2%} ± {stats['censor_std']:.2%}\")\n",
    "\n",
    "# Saving a dict stores a 0-d object array; reload it with\n",
    "# np.load(\"rewards_stats.npy\", allow_pickle=True).item()\n",
    "np.save(\"rewards_stats.npy\", stats)\n",
    "\n",
    "print(\"Successfully saved results to rewards_stats.npy\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
