{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OPL Experiment -- varying numbers of clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "from tqdm import tqdm\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.utils import check_random_state\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import japanize_matplotlib\n",
    "plt.style.use('ggplot')\n",
    "\n",
    "from dataset import process_kuairec_data\n",
    "from policylearners import RegBasedPolicyLearner, GradientBasedPolicyLearner, POTEC\n",
    "from utils import learn_q_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 30 # number of simulations\n",
    "max_iter = 20 # number of epochs\n",
    "num_data = 8000 # training data size\n",
    "beta = 5 # noise parameter of the logging policy\n",
    "tau = 2 # temperature parameter of the logging policy\n",
    "test_data_size = 100000 # test data size\n",
    "random_state = 12345\n",
    "torch.manual_seed(random_state)\n",
    "random_ = check_random_state(random_state)\n",
    "num_clusters_list = [10, 20, 50, 100, 200, 500, 1000] # training data sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_path = Path(f\"./result/{datetime.now().strftime('%Y-%m-%d')}/num_clusters\")\n",
    "result_path.mkdir(parents=True, exist_ok=True)\n",
    "result_file_name = f\"result_{datetime.now().strftime('%H:00')}.csv\"\n",
    "curve_file_name = f\"curve_{datetime.now().strftime('%H:00')}.csv\"\n",
    "print(result_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "num_clusters=10...: 100%|██████████| 15/15 [2:10:08<00:00, 520.56s/it] \n",
      "num_clusters=20...: 100%|██████████| 15/15 [1:52:14<00:00, 448.98s/it]\n",
      "num_clusters=50...: 100%|██████████| 15/15 [1:57:39<00:00, 470.65s/it]\n",
      "num_clusters=100...: 100%|██████████| 15/15 [1:54:42<00:00, 458.83s/it]\n",
      "num_clusters=200...: 100%|██████████| 15/15 [1:54:49<00:00, 459.30s/it]\n",
      "num_clusters=500...: 100%|██████████| 15/15 [1:54:34<00:00, 458.30s/it]\n",
      "num_clusters=1000...: 100%|██████████| 15/15 [1:53:10<00:00, 452.68s/it]\n"
     ]
    }
   ],
   "source": [
    "for num_clusters in num_clusters_list:\n",
    "    test_data = process_kuairec_data(num_data=test_data_size, beta=beta, tau=tau, random_state=random_state)\n",
    "    dim_x, num_actions = test_data[\"x\"].shape[1], test_data[\"num_actions\"]\n",
    "    pi_0_value = (test_data[\"q_x_a\"] * test_data[\"pi_0\"]).sum(1).mean()\n",
    "\n",
    "    curve = DataFrame()\n",
    "    for _ in tqdm(range(num_runs), desc=f\"num_clusters={num_clusters}...\"):\n",
    "        ## generate offline logged data\n",
    "        train_data = process_kuairec_data(num_data=num_data, beta=beta, tau=tau, random_state=_)\n",
    "\n",
    "        true_value_of_learned_policies = dict()\n",
    "        true_value_of_learned_policies[\"logging\"] = pi_0_value\n",
    "        log_ = DataFrame([[pi_0_value] * max_iter, [\"logging\"] * (max_iter + 1)], index=[\"value\", \"method\"]).T.reset_index()\n",
    "\n",
    "        ## perform OPL based on logged data\n",
    "        ### regression-based approach\n",
    "        reg = RegBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=_)\n",
    "        reg.fit(train_data, test_data)\n",
    "        pi_reg = reg.predict(test_data)\n",
    "        true_value_of_learned_policies[\"reg\"] = (test_data[\"q_x_a\"] * pi_reg).sum(1).mean()\n",
    "        reg_ = DataFrame([reg.test_value, [\"reg\"] * (max_iter + 1)], index=[\"value\", \"method\"]).T.reset_index()\n",
    "        ### gradient-based approach w/ IPS\n",
    "        ips = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=_)\n",
    "        ips.fit(train_data, test_data)\n",
    "        pi_ips = ips.predict(test_data)\n",
    "        true_value_of_learned_policies[\"ips-pg\"] = (test_data[\"q_x_a\"] * pi_ips).sum(1).mean()\n",
    "        ips_ = DataFrame([ips.test_value, [\"ips\"] * (max_iter + 1)], index=[\"value\", \"method\"]).T.reset_index()\n",
    "        ### gradient-based approach w/ DR\n",
    "        dr = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=_)\n",
    "        dr.fit(train_data, test_data, q_hat=reg.predict_q(train_data))\n",
    "        pi_dr = dr.predict(test_data)\n",
    "        true_value_of_learned_policies[\"dr-pg\"] = (test_data[\"q_x_a\"] * pi_dr).sum(1).mean()\n",
    "        dr_ = DataFrame([dr.test_value, [\"dr\"] * (max_iter + 1)], index=[\"value\", \"method\"]).T.reset_index()\n",
    "        ### POTEC algorithm\n",
    "        q_hat_train, q_hat_test = learn_q_model(train_data, test_data, lam=lam_dict[num_clusters])\n",
    "        km = KMeans(n_clusters=num_clusters)\n",
    "        train_data[\"phi_a\"] = km.fit_predict(q_hat_train.mean(0)[:, np.newaxis])\n",
    "        test_data[\"phi_a\"] = km.predict(q_hat_test.mean(0)[:, np.newaxis])\n",
    "        potec = POTEC(dim_x=dim_x, num_actions=num_actions, num_clusters=num_clusters, max_iter=max_iter, random_state=_)\n",
    "        potec.fit(train_data, test_data, f_hat=q_hat_train, f_hat_test=q_hat_test)\n",
    "        pi_potec = potec.predict(test_data, f_hat_test=q_hat_test)\n",
    "        true_value_of_learned_policies[\"potec\"] = (test_data[\"q_x_a\"] * pi_potec).sum(1).mean()\n",
    "        potec_ = DataFrame([potec.test_value, [\"potec\"] * (max_iter + 1)], index=[\"value\", \"method\"]).T.reset_index()\n",
    "\n",
    "        df_seed = DataFrame(true_value_of_learned_policies, index=[\"value\"]).T\n",
    "        df_seed = df_seed.reset_index().rename(columns={\"index\": \"method\"})\n",
    "        df_seed[\"seed\"] = _\n",
    "        df_seed[\"num_clusters\"] = num_clusters\n",
    "        df_seed[\"pi_0_value\"] = pi_0_value\n",
    "        df_seed[\"rel_value\"] = df_seed[\"value\"] / pi_0_value\n",
    "        if (result_path / result_file_name).exists():\n",
    "            df_ = pd.read_csv(result_path / result_file_name, index_col=0)\n",
    "            df_ = pd.concat([df_, df_seed]).reset_index(drop=True)\n",
    "            df_.to_csv(result_path / result_file_name)\n",
    "        else:\n",
    "            df_seed.to_csv(result_path / result_file_name)\n",
    "\n",
    "        curve = pd.concat([reg_, ips_, dr_, potec_]).rename(columns={\"index\": \"epoch\"})\n",
    "        curve[\"seed\"] = _\n",
    "        curve[\"num_clusters\"] = num_clusters\n",
    "        curve[\"rel_value\"] = curve.value / pi_0_value\n",
    "        if (result_path / curve_file_name).exists():\n",
    "            curve_ = pd.read_csv(result_path / curve_file_name, index_col=0)\n",
    "            curve_ = pd.concat([curve_, curve]).reset_index(drop=True)\n",
    "            curve_.to_csv(result_path / curve_file_name)\n",
    "        else:\n",
    "            curve.to_csv(result_path / curve_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
