{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from logging import getLogger\n",
    "from pathlib import Path\n",
    "import os\n",
    "import sys\n",
    "sys.path.append(os.pardir)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "from tqdm import tqdm\n",
    "from sklearn.utils import check_random_state\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import fix_seed, empty_metrics\n",
    "from run import run_dynamic_match\n",
    "\n",
    "from synthetic_data import generate_data, generate_reward_data, train_model\n",
    "from visualization_seed_variable import run_visual\n",
    "import conf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment variable swept by this notebook; it also names the result directory\n",
    "# (../result/<variable>) and the matching column of the results table.\n",
    "variable = \"n_xy\"\n",
    "# Upper bound of the time-step loop, taken from the shared experiment config.\n",
    "T = conf.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = getLogger(__name__)\n",
    "logger.info(f\"The current working directory is {Path().cwd()}\")\n",
    "\n",
    "# Output directory for this sweep: ../result/<variable>/df\n",
    "log_path = Path(f\"../result/{variable}\")\n",
    "df_path = log_path / \"df\"\n",
    "df_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "# Column order of the saved CSV (the same order the previous incremental pd.concat produced).\n",
    "result_columns = [\n",
    "    \"seed\", variable, \"t\", \"method\",\n",
    "    \"match_x\", \"match_y\",\n",
    "    \"active_users_x\", \"active_users_y\",\n",
    "    \"user_retain_x\", \"user_retain_y\",\n",
    "    \"true_user_retain_x\", \"true_user_retain_y\",\n",
    "    \"exposure_x\", \"exposure_y\",\n",
    "    \"fair_x\", \"fair_y\",\n",
    "    \"mse_user_retain_x\", \"mse_user_retain_y\",\n",
    "]\n",
    "\n",
    "# One record per (seed, n_xy, method, t). The DataFrame is built once after the loop:\n",
    "# growing it with pd.concat on every iteration is quadratic.\n",
    "records = []\n",
    "\n",
    "for seed in tqdm(range(conf.num_seeds), desc=\"Processing seeds\"):\n",
    "    # Seed-specific random state shared by train_model, generate_data and run_dynamic_match.\n",
    "    seed_state = conf.random_state + 1 + seed\n",
    "    for n_xy in tqdm(conf.n_xy_list, desc=f\"Processing {variable}\"):\n",
    "        # Split the total population evenly between the x and y sides (an odd total loses one).\n",
    "        n_x = n_xy // 2\n",
    "        n_y = n_xy // 2\n",
    "\n",
    "        random_ = check_random_state(seed_state)\n",
    "        results = {method: empty_metrics(conf.T, n_x, n_y) for method in conf.method_list}\n",
    "\n",
    "        # Reward data is generated from the base random_state (not seed_state),\n",
    "        # so its inputs do not change across seeds.\n",
    "        reward_data = generate_reward_data(\n",
    "            dim=conf.dim,\n",
    "            T=conf.T,\n",
    "            alpha_param=conf.alpha_param,\n",
    "            beta_param=conf.beta_param,\n",
    "            random_state=conf.random_state,\n",
    "        )\n",
    "\n",
    "        model = train_model(\n",
    "            dim=conf.dim,\n",
    "            T=conf.T,\n",
    "            n_train=conf.n_train,\n",
    "            reward_data=reward_data,\n",
    "            alpha_param=conf.alpha_param,\n",
    "            beta_param=conf.beta_param,\n",
    "            random_state=seed_state,\n",
    "        )\n",
    "\n",
    "        dataset = generate_data(\n",
    "            n_x=n_x,\n",
    "            n_y=n_y,\n",
    "            dim=conf.dim,\n",
    "            rel_noise=conf.rel_noise,\n",
    "            T=conf.T,\n",
    "            K=conf.K,\n",
    "            kappa=conf.kappa,\n",
    "            reward_data=reward_data,\n",
    "            alpha_param=conf.alpha_param,\n",
    "            beta_param=conf.beta_param,\n",
    "            random_state=seed_state,\n",
    "            random_=random_,\n",
    "        )\n",
    "\n",
    "        # run_dynamic_match is expected to fill `results` in place (its return value is unused).\n",
    "        run_dynamic_match(\n",
    "            dataset,\n",
    "            model=model,\n",
    "            proportion=conf.proportion,\n",
    "            reward_type=conf.reward_type,\n",
    "            ranking_metric=conf.ranking_metric,\n",
    "            results=results,\n",
    "            noise=conf.noise,\n",
    "            random_state=seed_state,\n",
    "        )\n",
    "\n",
    "        # Summarise each metric by its mean at every step t = 1 .. T-1.\n",
    "        for method, metrics in results.items():\n",
    "            for t in range(1, T):\n",
    "                retain_x = metrics[\"user_retain_x\"][t]\n",
    "                retain_y = metrics[\"user_retain_y\"][t]\n",
    "                true_retain_x = metrics[\"true_user_retain_x\"][t]\n",
    "                true_retain_y = metrics[\"true_user_retain_y\"][t]\n",
    "                records.append({\n",
    "                    \"seed\": seed,\n",
    "                    variable: n_xy,\n",
    "                    \"t\": t,\n",
    "                    \"method\": method,\n",
    "                    \"match_x\": metrics[\"match_x\"][t].mean(),\n",
    "                    \"match_y\": metrics[\"match_y\"][t].mean(),\n",
    "                    \"active_users_x\": metrics[\"active_users_x\"][t].mean(),\n",
    "                    \"active_users_y\": metrics[\"active_users_y\"][t].mean(),\n",
    "                    \"user_retain_x\": retain_x.mean(),\n",
    "                    \"user_retain_y\": retain_y.mean(),\n",
    "                    \"true_user_retain_x\": true_retain_x.mean(),\n",
    "                    \"true_user_retain_y\": true_retain_y.mean(),\n",
    "                    \"exposure_x\": metrics[\"exposure_x\"][t].mean(),\n",
    "                    \"exposure_y\": metrics[\"exposure_y\"][t].mean(),\n",
    "                    \"fair_x\": metrics[\"fair_x\"][t].mean(),\n",
    "                    \"fair_y\": metrics[\"fair_y\"][t].mean(),\n",
    "                    # Compare true and estimated retention at the SAME step t; indexing only\n",
    "                    # the true value would compare step t against the estimate for every step.\n",
    "                    \"mse_user_retain_x\": ((true_retain_x - retain_x) ** 2).mean(),\n",
    "                    \"mse_user_retain_y\": ((true_retain_y - retain_y) ** 2).mean(),\n",
    "                })\n",
    "\n",
    "all_data = pd.DataFrame(records, columns=result_columns)\n",
    "\n",
    "# Save to CSV\n",
    "all_data[\"t\"] = pd.to_numeric(all_data[\"t\"], errors=\"coerce\")\n",
    "all_data.to_csv(df_path / \"all_data_results.csv\", index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved sweep results from disk. The path is rebuilt here (same location the\n",
    "# simulation cell writes to) so plots can be regenerated without re-running the simulation.\n",
    "results_csv = Path(f\"../result/{variable}\") / \"df\" / \"all_data_results.csv\"\n",
    "all_data = pd.read_csv(results_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualise the collected results with run_visual (from visualization_seed_variable).\n",
    "# NOTE(review): the sweep varies n_xy per run, but n_x / n_y passed here are the fixed\n",
    "# values from conf -- confirm run_visual only needs them as defaults/labels.\n",
    "run_visual(\n",
    "    all_data=all_data,\n",
    "    variable=variable,\n",
    "    n_x=conf.n_x,\n",
    "    n_y=conf.n_y,\n",
    "    T=conf.T,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
