{
 "cells": [
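  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OPL Experiment: Varying the Number of Action Clusters\n",
    "\n",
    "Compares off-policy learning (OPL) methods -- a regression-based learner, IPS and DR policy gradients, and POTEC with either the true or KMeans-learned action clusters -- on synthetic data while the number of action clusters is varied. For each setting and seed, the true value of every learned policy (absolute and relative to the logging policy) and its per-epoch learning curve are appended to CSV files under `./result/<date>/num_clusters/`.\n",
    "\n",
    "*Note:* the saved progress bars in the experiment cell show 20 runs per setting, i.e. they come from an earlier configuration than the current `num_runs = 30`."
   ]
  },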
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from copy import deepcopy\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "from tqdm import tqdm\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.utils import check_random_state\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import japanize_matplotlib\n",
    "plt.style.use('ggplot')\n",
    "\n",
    "from dataset import generate_synthetic_data\n",
    "from policylearners import RegBasedPolicyLearner, GradientBasedPolicyLearner, POTEC\n",
    "from utils import learn_q_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 30 # number of simulations (random seeds) per num_clusters setting\n",
    "max_iter = 20 # number of epochs\n",
    "dim_x = 5 # dim context\n",
    "num_data = 4000 # size of the logged training data\n",
    "num_actions = 500 # number of candidate actions\n",
    "beta = 5 # noise parameter of the logging policy (NOTE(review): not passed to generate_synthetic_data below)\n",
    "tau = 2 # temperature parameter of the logging policy (NOTE(review): not passed to generate_synthetic_data below)\n",
    "test_data_size = 100000 # test data size\n",
    "random_state = 12345\n",
    "torch.manual_seed(random_state)\n",
    "random_ = check_random_state(random_state)\n",
    "num_clusters_list = [5, 10, 20, 50, 100, 200] # numbers of action clusters to sweep over\n",
    "# TODO(review): `lam` is the weight on the estimated q-function when it is blended with the true\n",
    "# q_x_a for the POTEC runs, i.e. (1 - lam) * q_x_a + lam * q_hat. The recorded run used a value left\n",
    "# over in the kernel that this notebook never sets; fill in the intended value before running.\n",
    "lam = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "result_20:00.csv\n"
     ]
    }
   ],
   "source": [
    "# Take the timestamp once so the date directory and the hour-stamped file names always agree\n",
    "# (separate datetime.now() calls could straddle midnight or an hour boundary).\n",
    "# File names are stamped only to the hour, so re-running within the same hour appends to the same\n",
    "# CSVs (the experiment loop appends whenever the file already exists).\n",
    "# NOTE(review): ':' is not allowed in Windows file names; '%H:00' only works on POSIX file systems.\n",
    "run_started_at = datetime.now()\n",
    "result_path = Path(f\"./result/{run_started_at.strftime('%Y-%m-%d')}/num_clusters\")\n",
    "result_path.mkdir(parents=True, exist_ok=True)\n",
    "result_file_name = f\"result_{run_started_at.strftime('%H:00')}.csv\"\n",
    "curve_file_name = f\"curve_{run_started_at.strftime('%H:00')}.csv\"\n",
    "print(result_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "num_clusters=5...: 100%|██████████| 20/20 [20:18<00:00, 60.94s/it]\n",
      "num_clusters=10...: 100%|██████████| 20/20 [20:45<00:00, 62.30s/it]\n",
      "num_clusters=20...: 100%|██████████| 20/20 [21:35<00:00, 64.76s/it]\n",
      "num_clusters=50...: 100%|██████████| 20/20 [24:50<00:00, 74.54s/it]\n",
      "num_clusters=100...: 100%|██████████| 20/20 [30:13<00:00, 90.66s/it]\n",
      "num_clusters=200...: 100%|██████████| 20/20 [29:07<00:00, 87.39s/it]\n"
     ]
    }
   ],
   "source": [
    "def policy_value(q_x_a, pi):\n",
    "    \"\"\"Expected reward of policy `pi` under the true reward table `q_x_a`, averaged over contexts.\"\"\"\n",
    "    return (q_x_a * pi).sum(1).mean()\n",
    "\n",
    "\n",
    "def learning_curve(test_value, method):\n",
    "    \"\"\"Per-epoch test values of one learner as a long frame (columns: index, value, method).\"\"\"\n",
    "    return DataFrame([test_value, [method] * (max_iter + 1)], index=[\"value\", \"method\"]).T.reset_index()\n",
    "\n",
    "\n",
    "def append_csv(df, path):\n",
    "    \"\"\"Append `df` to the CSV at `path` (creating it if missing), renumbering the combined index.\"\"\"\n",
    "    if path.exists():\n",
    "        df = pd.concat([pd.read_csv(path, index_col=0), df]).reset_index(drop=True)\n",
    "    df.to_csv(path)\n",
    "\n",
    "\n",
    "def fit_potec(num_clusters, seed, train, test, f_hat, f_hat_test):\n",
    "    \"\"\"Fit POTEC on `train` with the given q-hat tables; return (policy on `test`, per-epoch test values).\"\"\"\n",
    "    potec = POTEC(dim_x=dim_x, num_actions=num_actions, num_clusters=num_clusters, max_iter=max_iter, random_state=seed)\n",
    "    potec.fit(train, test, f_hat=f_hat, f_hat_test=f_hat_test)\n",
    "    return potec.predict(test, f_hat_test=f_hat_test), potec.test_value\n",
    "\n",
    "\n",
    "# `lam` is not set anywhere in this notebook (the recorded run relied on leftover kernel state).\n",
    "# Fail fast here rather than with a NameError after the first learners of the first seed have trained.\n",
    "if globals().get(\"lam\") is None:\n",
    "    raise ValueError(\"Set `lam` (weight on the estimated q-function in the POTEC runs) before running this cell.\")\n",
    "\n",
    "for num_clusters in num_clusters_list:\n",
    "    ## sample the environment for this number of clusters (draw order from `random_` is kept as-is)\n",
    "    phi_a = random_.choice(num_clusters, size=num_actions)\n",
    "    theta_g = random_.normal(size=(dim_x, num_clusters))\n",
    "    M_g = random_.normal(size=(dim_x, num_clusters))\n",
    "    b_g = random_.normal(size=(1, num_clusters))\n",
    "    theta_h = random_.normal(size=(dim_x, num_actions))\n",
    "    M_h = random_.normal(size=(dim_x, dim_x))\n",
    "    b_h = random_.normal(size=(1, num_actions))\n",
    "\n",
    "    test_data = generate_synthetic_data(\n",
    "        num_data=test_data_size,\n",
    "        theta_g=theta_g, M_g=M_g, b_g=b_g, theta_h=theta_h, M_h=M_h, b_h=b_h, phi_a=phi_a,\n",
    "        dim_context=dim_x, num_actions=num_actions, num_clusters=num_clusters, random_state=random_state\n",
    "    )\n",
    "    pi_0_value = policy_value(test_data[\"q_x_a\"], test_data[\"pi_0\"])\n",
    "\n",
    "    # seeds 10, 11, ..., 10 + num_runs - 1\n",
    "    for seed in tqdm(range(10, 10 + num_runs), desc=f\"num_clusters={num_clusters}...\"):\n",
    "        ## generate offline logged data\n",
    "        train_data = generate_synthetic_data(\n",
    "            num_data=num_data,\n",
    "            theta_g=theta_g, M_g=M_g, b_g=b_g, theta_h=theta_h, M_h=M_h, b_h=b_h, phi_a=phi_a,\n",
    "            dim_context=dim_x, num_actions=num_actions, num_clusters=num_clusters,\n",
    "            random_state=seed\n",
    "        )\n",
    "        true_value_of_learned_policies = {\"logging\": pi_0_value}\n",
    "\n",
    "        ## perform OPL based on logged data\n",
    "        # NOTE(review): labels differ between the result file (\"ips-pg\", \"dr-pg\") and the curve file (\"ips\", \"dr\").\n",
    "        ### regression-based approach\n",
    "        reg = RegBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=seed)\n",
    "        reg.fit(train_data, test_data)\n",
    "        true_value_of_learned_policies[\"reg\"] = policy_value(test_data[\"q_x_a\"], reg.predict(test_data))\n",
    "        reg_ = learning_curve(reg.test_value, \"reg\")\n",
    "        ### gradient-based approach w/ IPS\n",
    "        ips = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=seed)\n",
    "        ips.fit(train_data, test_data)\n",
    "        true_value_of_learned_policies[\"ips-pg\"] = policy_value(test_data[\"q_x_a\"], ips.predict(test_data))\n",
    "        ips_ = learning_curve(ips.test_value, \"ips\")\n",
    "        ### gradient-based approach w/ DR (q-hat taken from the regression learner)\n",
    "        dr = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_actions, max_iter=max_iter, random_state=seed)\n",
    "        dr.fit(train_data, test_data, q_hat=reg.predict_q(train_data))\n",
    "        true_value_of_learned_policies[\"dr-pg\"] = policy_value(test_data[\"q_x_a\"], dr.predict(test_data))\n",
    "        dr_ = learning_curve(dr.test_value, \"dr\")\n",
    "\n",
    "        ### POTEC: the q-hat passed to POTEC is (1 - w) * true q_x_a + w * learned q-hat for a blend weight w\n",
    "        q_hat_train_, q_hat_test_ = learn_q_model(train_data, test_data, lam=1)\n",
    "        ### POTEC w/ true clusters, w = lam - 0.085\n",
    "        lam_ = lam - 0.085\n",
    "        q_hat_train = (1 - lam_) * train_data[\"q_x_a\"] + lam_ * q_hat_train_\n",
    "        q_hat_test = (1 - lam_) * test_data[\"q_x_a\"] + lam_ * q_hat_test_\n",
    "        pi_potec_true_pair, test_value = fit_potec(num_clusters, seed, train_data, test_data, q_hat_train, q_hat_test)\n",
    "        true_value_of_learned_policies[\"potec_true_pair\"] = policy_value(test_data[\"q_x_a\"], pi_potec_true_pair)\n",
    "        potec_true_pair = learning_curve(test_value, \"potec_true_pair\")\n",
    "        ### POTEC w/ true clusters, w = lam (otherwise identical to the run above)\n",
    "        q_hat_train = (1 - lam) * train_data[\"q_x_a\"] + lam * q_hat_train_\n",
    "        q_hat_test = (1 - lam) * test_data[\"q_x_a\"] + lam * q_hat_test_\n",
    "        pi_potec_true_abs, test_value = fit_potec(num_clusters, seed, train_data, test_data, q_hat_train, q_hat_test)\n",
    "        true_value_of_learned_policies[\"potec_true_abs\"] = policy_value(test_data[\"q_x_a\"], pi_potec_true_abs)\n",
    "        potec_true_abs = learning_curve(test_value, \"potec_true_abs\")\n",
    "        ### POTEC w/ clusters learned by KMeans on the context-averaged q-hat (w = lam)\n",
    "        # Use shallow copies with a replaced \"phi_a\": test_data is shared by every seed of this num_clusters,\n",
    "        # so overwriting its true clusters in place would leak these learned clusters into the true-cluster\n",
    "        # runs of the following seeds. (Assumes generate_synthetic_data returns a dict.)\n",
    "        km = KMeans(n_clusters=num_clusters, random_state=seed)\n",
    "        train_data_km = {**train_data, \"phi_a\": km.fit_predict(q_hat_train.mean(0)[:, np.newaxis])}\n",
    "        test_data_km = {**test_data, \"phi_a\": km.predict(q_hat_test.mean(0)[:, np.newaxis])}\n",
    "        pi_potec_learned, test_value = fit_potec(num_clusters, seed, train_data_km, test_data_km, q_hat_train, q_hat_test)\n",
    "        true_value_of_learned_policies[\"potec_learned\"] = policy_value(test_data[\"q_x_a\"], pi_potec_learned)\n",
    "        potec_learned = learning_curve(test_value, \"potec_learned\")\n",
    "\n",
    "        ## save results after every seed so an interrupted run keeps its progress\n",
    "        df_seed = DataFrame(true_value_of_learned_policies, index=[\"value\"]).T\n",
    "        df_seed = df_seed.reset_index().rename(columns={\"index\": \"method\"})\n",
    "        df_seed[\"seed\"] = seed\n",
    "        df_seed[\"num_clusters\"] = num_clusters\n",
    "        df_seed[\"pi_0_value\"] = pi_0_value\n",
    "        df_seed[\"rel_value\"] = df_seed[\"value\"] / pi_0_value\n",
    "        append_csv(df_seed, result_path / result_file_name)\n",
    "\n",
    "        curve = pd.concat([reg_, ips_, dr_, potec_true_pair, potec_true_abs, potec_learned]).rename(columns={\"index\": \"epoch\"})\n",
    "        curve[\"seed\"] = seed\n",
    "        curve[\"num_clusters\"] = num_clusters\n",
    "        curve[\"rel_value\"] = curve[\"value\"] / pi_0_value\n",
    "        append_csv(curve, result_path / curve_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
