{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from logging import getLogger\n",
    "from pathlib import Path\n",
    "import os\n",
    "import sys\n",
    "sys.path.append(os.pardir)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "from tqdm import tqdm\n",
    "from sklearn.utils import check_random_state\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import fix_seed, empty_metrics\n",
    "from run import run_dynamic_match\n",
    "from visualization_seed import plot_match_per, plot_number_user_retain, plot_user_retain\n",
    "\n",
    "from synthetic_data import generate_data, generate_reward_data, train_model\n",
    "import conf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf.T = 50\n",
    "conf.n_x = 20\n",
    "conf.n_y = 20\n",
    "conf.K = 3\n",
    "conf.candidate_retention = 0.02\n",
    "conf.method_list = ['MRet (best)','optimal_ranking']\n",
    "conf.show_method_list = conf.method_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_data = generate_reward_data(\n",
    "    dim = conf.dim,\n",
    "    T = conf.T,\n",
    "    alpha_param=conf.alpha_param,\n",
    "    beta_param=conf.beta_param,\n",
    "    random_state = conf.random_state)\n",
    "\n",
    "model = train_model(\n",
    "    dim = conf.dim,\n",
    "    T = conf.T,\n",
    "    n_train = conf.n_train,\n",
    "    reward_data = reward_data,\n",
    "    alpha_param=conf.alpha_param,\n",
    "    beta_param=conf.beta_param,\n",
    "    random_state = conf.random_state)\n",
    "\n",
    "T=conf.T\n",
    "n_x=conf.n_x\n",
    "n_y=conf.n_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "logger = getLogger(__name__)\n",
    "logger.info(f\"The current working directory is {Path().cwd()}\")\n",
    "\n",
    "# log path\n",
    "log_path = Path(\"../result/T_optimal\")\n",
    "df_path = log_path / \"df\"\n",
    "df_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "\n",
    "# DataFrame to store results of all seeds\n",
    "all_data = pd.DataFrame(columns=[\"seed\", \"t\", \"method\", \"match_x\", \"match_y\", \"active_users_x\", \"active_users_y\", \"user_retain_x\", \"user_retain_y\", \"true_user_retain_x\", \"true_user_retain_y\"])\n",
    "\n",
    "for seed in tqdm(range(conf.num_seeds), desc=\"Processing seeds\"):\n",
    "    random_ = check_random_state(conf.random_state + 1 + seed)\n",
    "    results = {method: empty_metrics(conf.T, n_x, n_y) for method in conf.method_list}\n",
    "\n",
    "    dataset = generate_data(\n",
    "        n_x = conf.n_x,\n",
    "        n_y = conf.n_y,\n",
    "        dim = conf.dim,\n",
    "        rel_noise = conf.rel_noise,\n",
    "        T = conf.T,\n",
    "        K = conf.K,\n",
    "        kappa=conf.kappa,\n",
    "        reward_data=reward_data,\n",
    "        alpha_param=conf.alpha_param,\n",
    "        beta_param=conf.beta_param,\n",
    "        random_state = conf.random_state + 1 + seed,\n",
    "        random_=random_,\n",
    "        ) \n",
    "    \n",
    "    run_dynamic_match(\n",
    "        dataset, \n",
    "        model=model,\n",
    "        proportion=conf.proportion,\n",
    "        reward_type=conf.reward_type,\n",
    "        ranking_metric=conf.ranking_metric, \n",
    "        results=results, \n",
    "        noise=conf.noise,\n",
    "        candidate_retention=conf.candidate_retention,\n",
    "        random_state=conf.random_state+1+seed, \n",
    "    )\n",
    "    \n",
    "\n",
    "    temp_data = []\n",
    "    for method, metrics in results.items():\n",
    "        for t in range(1, conf.T):\n",
    "            temp_data.append({\n",
    "                \"seed\": seed,\n",
    "                \"t\": t,\n",
    "                \"method\": method,\n",
    "                \"match_x\": metrics[\"match_x\"][t].mean(),\n",
    "                \"match_y\": metrics[\"match_y\"][t].mean(),\n",
    "                \"exposure_x\": metrics[\"exposure_x\"][t].mean(),\n",
    "                \"exposure_y\": metrics[\"exposure_y\"][t].mean(),\n",
    "                \"fair_x\": metrics[\"fair_x\"][t].mean(),\n",
    "                \"fair_y\": metrics[\"fair_y\"][t].mean(),\n",
    "                \"active_users_x\": metrics[\"active_users_x\"][t].mean(),\n",
    "                \"active_users_y\": metrics[\"active_users_y\"][t].mean(),\n",
    "                \"user_retain_x\": metrics[\"user_retain_x\"][t].mean(),\n",
    "                \"user_retain_y\": metrics[\"user_retain_y\"][t].mean(),\n",
    "                \"true_user_retain_x\": metrics[\"true_user_retain_x\"][t].mean(),\n",
    "                \"true_user_retain_y\": metrics[\"true_user_retain_y\"][t].mean(),\n",
    "            })\n",
    "\n",
    "    # Concatenate temporary DataFrame\n",
    "    all_data = pd.concat([all_data, pd.DataFrame(temp_data)], ignore_index=True)\n",
    "\n",
    "# Save to CSV\n",
    "all_data[\"t\"] = pd.to_numeric(all_data[\"t\"], errors=\"coerce\")\n",
    "all_data.to_csv(df_path / \"all_data_results.csv\", index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(18, 6))\n",
    "plot_match_per(all_data, ax=axes[0], side=\"both\", n_x=n_x, n_y=n_y)\n",
    "plot_number_user_retain(all_data, ax=axes[1], side=\"both\", n_x=n_x, n_y=n_y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
