{
 "cells": [
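  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Off-Policy Learning (OPL) Experiment: Varying Training Data Sizes\n",
    "\n",
    "For each training-data size in `num_data_list`, this notebook repeatedly generates logged data with `process_kuairec_data`, trains a regression-based policy, IPS- and DR-based policy-gradient policies, and POTEC, and records each learned policy's true value on a fixed test set (plus per-epoch learning curves) in CSV files under `./result/`."
   ]
  },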
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "from tqdm import tqdm\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.utils import check_random_state\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import japanize_matplotlib\n",
    "plt.style.use('ggplot')\n",
    "\n",
    "from dataset import process_kuairec_data\n",
    "from policylearners import RegBasedPolicyLearner, GradientBasedPolicyLearner, POTEC\n",
    "from utils import learn_q_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 30 # number of simulations\n",
    "max_iter = 20 # number of epochs\n",
    "num_clusters = 30 # number of action clusters (KMeans `n_clusters` and POTEC `num_clusters`)\n",
    "beta = 5 # noise parameter of the logging policy\n",
    "tau = 2 # temperature parameter of the logging policy\n",
    "test_data_size = 100000 # test data size\n",
    "num_data_list = [2000, 4000, 8000, 16000, 32000] # training data sizes\n",
    "# NOTE(review): the experiment cell also uses `lam` (passed to `learn_q_model`), which is not defined\n",
    "# anywhere in this notebook. TODO: define it here with the value used for the reported results.\n",
    "\n",
    "## random seeds\n",
    "random_state = 12345\n",
    "torch.manual_seed(random_state)\n",
    "np.random.seed(random_state)  # numpy's global RNG, used e.g. by KMeans when no random_state is passed\n",
    "random_ = check_random_state(random_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one timestamp for the directory and both file names, so they cannot disagree across a date/hour boundary\n",
    "started_at = datetime.now()\n",
    "result_path = Path(f\"./result/{started_at.strftime('%Y-%m-%d')}/train_data\")\n",
    "result_path.mkdir(parents=True, exist_ok=True)\n",
    "# '-' instead of ':' in the hour stamp: ':' is a reserved character in Windows file names\n",
    "result_file_name = f\"result_{started_at.strftime('%H-00')}.csv\"\n",
    "curve_file_name = f\"curve_{started_at.strftime('%H-00')}.csv\"\n",
    "print(result_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "num_data=2000...: 100%|██████████| 15/15 [8:41:52<00:00, 2087.50s/it]  \n",
      "num_data=4000...: 100%|██████████| 15/15 [1:41:13<00:00, 404.90s/it]\n",
      "num_data=8000...: 100%|██████████| 15/15 [1:45:49<00:00, 423.30s/it]\n",
      "num_data=16000...: 100%|██████████| 15/15 [1:55:00<00:00, 460.04s/it]\n",
      "num_data=32000...: 100%|██████████| 15/15 [2:18:55<00:00, 555.69s/it] \n"
     ]
    }
   ],
   "source": [
    "## helpers\n",
    "def policy_value(q_x_a, pi):\n",
    "    \"\"\"True value of `pi`: sum of q_x_a * pi over axis 1 (actions), averaged over axis 0 (rounds).\"\"\"\n",
    "    return (q_x_a * pi).sum(1).mean()\n",
    "\n",
    "\n",
    "def learning_curve(test_values, method):\n",
    "    \"\"\"Long-format learning curve with columns (epoch, value, method), one row per recorded test value.\"\"\"\n",
    "    return DataFrame({\"epoch\": range(len(test_values)), \"value\": list(test_values), \"method\": method})\n",
    "\n",
    "\n",
    "def append_csv(df, path):\n",
    "    \"\"\"Append `df` to the CSV at `path` (created on first use), renumbering the index 0..N-1.\"\"\"\n",
    "    if path.exists():\n",
    "        df = pd.concat([pd.read_csv(path, index_col=0), df]).reset_index(drop=True)\n",
    "    df.to_csv(path)\n",
    "\n",
    "\n",
    "# `lam` is passed to `learn_q_model` but is not defined in this notebook; fail fast here rather than\n",
    "# after the first run has already spent time training the reg / IPS / DR learners.\n",
    "if \"lam\" not in globals():\n",
    "    raise NameError(\"`lam` is not defined: set it in the configuration cell before running the experiment.\")\n",
    "\n",
    "test_data = process_kuairec_data(num_data=test_data_size, beta=beta, tau=tau, random_state=random_state)\n",
    "dim_x, num_actions = test_data[\"x\"].shape[1], test_data[\"num_actions\"]\n",
    "pi_0_value = policy_value(test_data[\"q_x_a\"], test_data[\"pi_0\"])\n",
    "\n",
    "for num_data in num_data_list:\n",
    "    for run in tqdm(range(num_runs), desc=f\"num_data={num_data}...\"):\n",
    "        seed = run + 10  # per-run seed: 10, 11, ..., num_runs + 9\n",
    "        ## generate offline logged data\n",
    "        train_data = process_kuairec_data(num_data=num_data, beta=beta, tau=tau, random_state=seed)\n",
    "        true_value_of_learned_policies = {\"logging\": pi_0_value}\n",
    "\n",
    "        ## perform OPL based on logged data\n",
    "        ### regression-based approach\n",
    "        reg = RegBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=seed)\n",
    "        reg.fit(train_data, test_data)\n",
    "        true_value_of_learned_policies[\"reg\"] = policy_value(test_data[\"q_x_a\"], reg.predict(test_data))\n",
    "        ### gradient-based approach w/ IPS\n",
    "        ips = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=seed)\n",
    "        ips.fit(train_data, test_data)\n",
    "        true_value_of_learned_policies[\"ips-pg\"] = policy_value(test_data[\"q_x_a\"], ips.predict(test_data))\n",
    "        ### gradient-based approach w/ DR (q_hat taken from the regression-based learner above)\n",
    "        dr = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=seed)\n",
    "        dr.fit(train_data, test_data, q_hat=reg.predict_q(train_data))\n",
    "        true_value_of_learned_policies[\"dr-pg\"] = policy_value(test_data[\"q_x_a\"], dr.predict(test_data))\n",
    "        ### POTEC algorithm\n",
    "        # NOTE(review): lam is reduced by 0.015 for training sets larger than 10000; rationale not recorded here.\n",
    "        lam_ = lam - 0.015 if num_data > 10000 else lam\n",
    "        q_hat_train, q_hat_test = learn_q_model(train_data, test_data, lam=lam_)\n",
    "        # KMeans on q_hat.mean(0) (presumably each action's mean estimated q-value); seeded for reproducibility\n",
    "        km = KMeans(n_clusters=num_clusters, random_state=seed)\n",
    "        train_data[\"phi_a\"] = km.fit_predict(q_hat_train.mean(0)[:, np.newaxis])\n",
    "        test_data[\"phi_a\"] = km.predict(q_hat_test.mean(0)[:, np.newaxis])\n",
    "        potec = POTEC(dim_x=dim_x, num_actions=num_actions, num_clusters=num_clusters, max_iter=max_iter, random_state=seed)\n",
    "        potec.fit(train_data, test_data, f_hat=q_hat_train, f_hat_test=q_hat_test)\n",
    "        pi_potec = potec.predict(test_data, f_hat_test=q_hat_test)\n",
    "        true_value_of_learned_policies[\"potec\"] = policy_value(test_data[\"q_x_a\"], pi_potec)\n",
    "\n",
    "        ## append the final value of every method for this run to the result CSV\n",
    "        df_seed = DataFrame({\"method\": list(true_value_of_learned_policies), \"value\": list(true_value_of_learned_policies.values())})\n",
    "        df_seed = df_seed.assign(seed=seed, num_data=num_data, pi_0_value=pi_0_value)\n",
    "        df_seed[\"rel_value\"] = df_seed[\"value\"] / pi_0_value\n",
    "        append_csv(df_seed, result_path / result_file_name)\n",
    "\n",
    "        ## append the per-epoch test values (learning curves) for this run to the curve CSV\n",
    "        curve = pd.concat([\n",
    "            learning_curve(reg.test_value, \"reg\"),\n",
    "            learning_curve(ips.test_value, \"ips\"),\n",
    "            learning_curve(dr.test_value, \"dr\"),\n",
    "            learning_curve(potec.test_value, \"potec\"),\n",
    "        ])\n",
    "        curve = curve.assign(seed=seed, num_data=num_data)\n",
    "        curve[\"rel_value\"] = curve[\"value\"] / pi_0_value\n",
    "        append_csv(curve, result_path / curve_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
