{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from cancer import CancerSim, PolicyCancer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_data(n, ttype=\"PCV\", penalty=1.0):\n",
    "    \"\"\"Simulate ``n`` episodes and store them in fixed-size arrays.\n",
    "\n",
    "    A new ``PolicyCancer`` is created every ``episodes_per_policy`` episodes\n",
    "    and a new ``CancerSim`` environment is created for every episode.\n",
    "\n",
    "    Args:\n",
    "        n: Number of episodes to simulate.\n",
    "        ttype: Passed to ``CancerSim`` as ``therapy_type``.\n",
    "        penalty: Passed to ``CancerSim`` as ``dose_penalty``.\n",
    "\n",
    "    Returns:\n",
    "        dict of arrays whose leading axes are ``(n, env.max_steps)``. Steps\n",
    "        after the end of an episode keep their initial fill value.\n",
    "        ``'sparse_rewards'`` and ``'dense_rewards'`` are allocated but never\n",
    "        written here, and ``'nn_action_dist'`` stays at its 1e9 fill value.\n",
    "    \"\"\"\n",
    "    episodes_per_policy = 10\n",
    "\n",
    "    def make_env():\n",
    "        # Single place for the environment settings used by every rollout.\n",
    "        return CancerSim(dose_penalty=penalty, therapy_type=ttype, env_seed=None, max_steps=30)\n",
    "\n",
    "    # This first environment is only used to size the buffers.\n",
    "    env = make_env()\n",
    "    horizon = env.max_steps\n",
    "\n",
    "    # np.int was an alias of the builtin int and was removed in NumPy 1.24,\n",
    "    # so the builtin is used for the action dtype.\n",
    "    dataset = {'observations': np.zeros((n, horizon, env.state_dim)),\n",
    "               'actions': np.zeros((n, horizon, 1), int),\n",
    "               'rewards': np.zeros((n, horizon, 1)),\n",
    "               'sparse_rewards': np.zeros((n, horizon, 1)),\n",
    "               'dense_rewards': np.zeros((n, horizon, 1)),\n",
    "               'not_done': np.zeros((n, horizon, 1)),\n",
    "               'pibs': np.zeros((n, horizon, env.num_actions)),\n",
    "               'nn_action_dist': np.ones((n, horizon, env.num_actions)) * 1e9,\n",
    "               }\n",
    "\n",
    "    policy = None\n",
    "    for i in range(n):\n",
    "        # i == 0 always takes this branch, so policy is set before first use.\n",
    "        if (i % episodes_per_policy) == 0:\n",
    "            policy = PolicyCancer(months_for_treatment=9, eps_behavior=0.3)\n",
    "\n",
    "        env = make_env()\n",
    "        obs = env.reset()\n",
    "        t = 0\n",
    "        done = False\n",
    "        # The t < horizon bound keeps writes inside the preallocated buffers\n",
    "        # even if the environment does not report done within max_steps.\n",
    "        while not done and t < horizon:\n",
    "            a = policy(obs, t)\n",
    "            new_obs, rt, done, _ = env.step(a)\n",
    "\n",
    "            # Each row stores the observation the action was chosen from.\n",
    "            dataset['observations'][i, t, :] = obs\n",
    "            dataset['actions'][i, t, :] = a\n",
    "            dataset['rewards'][i, t, :] = rt\n",
    "            dataset['not_done'][i, t, :] = float(1-done)\n",
    "            dataset['pibs'][i, t, :] = policy.return_probs(obs, t)\n",
    "\n",
    "            obs = new_obs\n",
    "            t += 1\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_data(dataset1, dataset2=None):\n",
    "    \"\"\"Fill ``dataset1['nn_action_dist']`` with nearest-neighbour distances.\n",
    "\n",
    "    For each action ``a`` a 1-nearest-neighbour index is fitted on the\n",
    "    observations of ``dataset2`` whose recorded action equals ``a``. Every\n",
    "    observation of ``dataset1`` is queried against it, and the squared\n",
    "    distance divided by the observation dimension is written to\n",
    "    ``dataset1['nn_action_dist'][:, :, a]``.\n",
    "\n",
    "    Args:\n",
    "        dataset1: dict with 'observations' of shape (n, horizon, state_dim),\n",
    "            'pibs' whose last axis has one entry per action, and\n",
    "            'nn_action_dist'. It is modified in place.\n",
    "        dataset2: Reference dict with 'observations' and 'actions'.\n",
    "            Defaults to ``dataset1`` itself.\n",
    "\n",
    "    Returns:\n",
    "        ``dataset1``, the same object that was passed in.\n",
    "\n",
    "    NOTE(review): every (episode, step) entry takes part in the fit and the\n",
    "    query, including steps that may never have been written - confirm that\n",
    "    this is intended.\n",
    "    \"\"\"\n",
    "    if dataset2 is None:\n",
    "        dataset2 = dataset1\n",
    "    n = dataset1['observations'].shape[0]\n",
    "    horizon = dataset1['observations'].shape[1]\n",
    "    state_dim = dataset1['observations'].shape[-1]\n",
    "    num_actions = dataset1[\"pibs\"].shape[-1]\n",
    "\n",
    "    # Flatten (episode, step) into one axis of query points.\n",
    "    X = dataset1[\"observations\"].reshape(-1, state_dim)\n",
    "    for a in range(num_actions):\n",
    "        is_action = (dataset2['actions'][:, :, 0] == a)\n",
    "        if not is_action.any():\n",
    "            # Fitting on zero samples raises, so an action that never occurs\n",
    "            # in dataset2 keeps whatever values the slice already holds.\n",
    "            continue\n",
    "        tree = NearestNeighbors(n_neighbors=1)\n",
    "        tree.fit(dataset2[\"observations\"][is_action, :])\n",
    "        # kneighbors returns (distances, indices); only distances are used.\n",
    "        dists = (tree.kneighbors(X)[0]**2)/state_dim\n",
    "        dataset1['nn_action_dist'][:, :, a] = dists.reshape(n, horizon)\n",
    "    return dataset1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build five train/validation dataset pairs and pickle each one.\n",
    "for i in range(1, 6):\n",
    "    dataset_pcv_high = collect_data(1000)\n",
    "    dataset_pcv_high = process_data(dataset_pcv_high)\n",
    "\n",
    "    # Validation distances are measured against the training observations.\n",
    "    dataset_pcv_high_val = collect_data(1000)\n",
    "    dataset_pcv_high_val = process_data(dataset_pcv_high_val, dataset_pcv_high)\n",
    "\n",
    "    with open(f\"../../data/cancer_pcv{i}_train_episodes\", 'wb') as f:\n",
    "        pickle.dump(dataset_pcv_high, f)\n",
    "\n",
    "    with open(f\"../../data/cancer_pcv{i}_val_episodes\", 'wb') as f:\n",
    "        pickle.dump(dataset_pcv_high_val, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
