{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from logging import getLogger\n",
    "from pathlib import Path\n",
    "import os\n",
    "import sys\n",
    "sys.path.append(os.pardir)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "from tqdm import tqdm\n",
    "from sklearn.utils import check_random_state\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import fix_seed, empty_metrics\n",
    "from run import run_dynamic_match\n",
    "\n",
    "\n",
    "from synthetic_data import generate_data, generate_reward_data, train_model\n",
    "import conf"
   ]
  },
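  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook sweeps the fairness weight `lambda_` over `conf.lambda_list` across multiple random seeds. For each seed it generates reward data, trains a model, and builds a dataset once; the lambda-independent baselines run once per seed, while FairCo is rerun for every `lambda_`. Per-step metrics are saved to `../result/lambda_/df/all_data_results.csv` and visualized at the end."
   ]
  },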
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "variable = \"lambda_\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "T=conf.T\n",
    "n_x=conf.n_x\n",
    "n_y=conf.n_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "logger = getLogger(__name__)\n",
    "logger.info(f\"The current working directory is {Path().cwd()}\")\n",
    "\n",
    "# log path\n",
    "log_path = Path(f\"../result/{variable}\")\n",
    "df_path = log_path / \"df\"\n",
    "df_path.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "\n",
    "# DataFrame to store results of all seeds\n",
    "all_data = pd.DataFrame(columns=[\"seed\", variable, \"t\", \"method\", \"match_x\", \"match_y\", \"active_users_x\", \"active_users_y\", \"user_retain_x\", \"user_retain_y\", \"true_user_retain_x\", \"true_user_retain_y\"])\n",
    "\n",
    "\n",
    "for seed in tqdm(range(conf.num_seeds), desc=\"Processing seeds\"):\n",
    "    random_ = check_random_state(conf.random_state + 1 + seed)\n",
    "    reward_data = generate_reward_data(\n",
    "        dim = conf.dim,\n",
    "        T = conf.T,\n",
    "        alpha_param=conf.alpha_param,\n",
    "        beta_param=conf.beta_param,\n",
    "        random_state = conf.random_state)\n",
    "\n",
    "    model = train_model(\n",
    "        dim = conf.dim,\n",
    "        T = conf.T,\n",
    "        n_train = conf.n_train,\n",
    "        reward_data = reward_data,\n",
    "        alpha_param=conf.alpha_param,\n",
    "        beta_param=conf.beta_param,\n",
    "        random_state = conf.random_state + 1 + seed)\n",
    "    \n",
    "    dataset = generate_data(\n",
    "        n_x = conf.n_x,\n",
    "        n_y = conf.n_y,\n",
    "        dim = conf.dim,\n",
    "        rel_noise = conf.rel_noise,\n",
    "        T = conf.T,\n",
    "        K = conf.K,\n",
    "        kappa=conf.kappa,\n",
    "        reward_data=reward_data,\n",
    "        alpha_param=conf.alpha_param,\n",
    "        beta_param=conf.beta_param,\n",
    "        random_state = conf.random_state + 1 + seed,\n",
    "        random_=random_,\n",
    "        ) \n",
    "    \n",
    "\n",
    "    static_results_by_method = {}\n",
    "    for method in conf.method_list:\n",
    "        if method == \"FairCo\":\n",
    "            continue\n",
    "        # Train fixed model\n",
    "\n",
    "        # Empty result dict\n",
    "        results = {method: empty_metrics(conf.T, conf.n_x, conf.n_y)}\n",
    "        run_dynamic_match(\n",
    "            dataset,\n",
    "            model=model,\n",
    "            proportion=conf.proportion,\n",
    "            reward_type=conf.reward_type,\n",
    "            ranking_metric=conf.ranking_metric,\n",
    "            results=results,\n",
    "            noise=conf.noise,\n",
    "            random_state=conf.random_state + 1 + seed,\n",
    "        )\n",
    "        static_results_by_method[method] = results[method]\n",
    "\n",
    "    # === Execute for all lambdas ===\n",
    "    for lambda_ in tqdm(conf.lambda_list, desc=\"Processing lambda\"):\n",
    "        temp_data = []\n",
    "        for method in conf.method_list:\n",
    "            if method == \"FairCo\":\n",
    "                results = {method: empty_metrics(conf.T, conf.n_x, conf.n_y)}\n",
    "                run_dynamic_match(\n",
    "                    dataset,\n",
    "                    model=model,\n",
    "                    proportion=conf.proportion,\n",
    "                    reward_type=conf.reward_type,\n",
    "                    ranking_metric=conf.ranking_metric,\n",
    "                    results=results,\n",
    "                    noise=conf.noise,\n",
    "                    lambda_=lambda_,\n",
    "                    random_state=conf.random_state + 1 + seed,\n",
    "                )\n",
    "                metrics = results[method]\n",
    "            else:\n",
    "                metrics = static_results_by_method[method]\n",
    "\n",
    "            # Record results (format as before)\n",
    "            for t in range(1, conf.T):\n",
    "                temp_data.append({\n",
    "                    \"seed\": seed,\n",
    "                    variable: lambda_,\n",
    "                    \"t\": t,\n",
    "                    \"method\": method,\n",
    "                    \"match_x\": metrics[\"match_x\"][t].mean(),\n",
    "                    \"match_y\": metrics[\"match_y\"][t].mean(),\n",
    "                    \"exposure_x\": metrics[\"exposure_x\"][t].mean(),\n",
    "                    \"exposure_y\": metrics[\"exposure_y\"][t].mean(),\n",
    "                    \"fair_x\": metrics[\"fair_x\"][t].mean(),\n",
    "                    \"fair_y\": metrics[\"fair_y\"][t].mean(),\n",
    "                    \"active_users_x\": metrics[\"active_users_x\"][t].mean(),\n",
    "                    \"active_users_y\": metrics[\"active_users_y\"][t].mean(),\n",
    "                    \"user_retain_x\": metrics[\"user_retain_x\"][t].mean(),\n",
    "                    \"user_retain_y\": metrics[\"user_retain_y\"][t].mean(),\n",
    "                    \"true_user_retain_x\": metrics[\"true_user_retain_x\"][t].mean(),\n",
    "                    \"true_user_retain_y\": metrics[\"true_user_retain_y\"][t].mean(),\n",
    "                    \"mse_user_retain_x\": ((metrics[\"true_user_retain_x\"][t] - metrics[\"user_retain_x\"])**2).mean(),\n",
    "                    \"mse_user_retain_y\": ((metrics[\"true_user_retain_y\"][t] - metrics[\"user_retain_y\"])**2).mean(),\n",
    "                })\n",
    "\n",
    "        all_data = pd.concat([all_data, pd.DataFrame(temp_data)], ignore_index=True)\n",
    "\n",
    "# Save\n",
    "all_data[\"t\"] = pd.to_numeric(all_data[\"t\"], errors=\"coerce\")\n",
    "all_data.to_csv(df_path / \"all_data_results.csv\", index=False)\n"
   ]
  },
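  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick aggregate check before plotting (a minimal sketch relying only on the columns recorded above): mean match counts at the final recorded step, per method and `lambda_`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: mean match counts at the last recorded step,\n",
    "# grouped by method and lambda_ (columns written by the loop above).\n",
    "final_t = all_data[\"t\"].max()\n",
    "summary = (\n",
    "    all_data[all_data[\"t\"] == final_t]\n",
    "    .groupby([\"method\", variable])[[\"match_x\", \"match_y\"]]\n",
    "    .mean()\n",
    ")\n",
    "summary"
   ]
  },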
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = pd.read_csv(df_path / \"all_data_results.csv\")"
   ]
  },
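  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After reloading from CSV, a brief check (a sketch, nothing method-specific) that the seeds and `lambda_` values round-tripped intact."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm the expected seeds and lambda_ values survived the CSV round trip.\n",
    "print(f\"rows: {len(all_data)}\")\n",
    "print(f\"seeds: {sorted(all_data['seed'].unique())}\")\n",
    "print(f\"{variable} values: {sorted(all_data[variable].unique())}\")"
   ]
  },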
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from visualization_seed_variable import run_visual\n",
    "\n",
    "run_visual(\n",
    "    all_data=all_data,\n",
    "    variable=variable,\n",
    "    n_x=conf.n_x,\n",
    "    n_y=conf.n_y,\n",
    "    T=conf.T,\n",
    "    x_log_scale=True\n",
    ")"
   ]
  },
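  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a lightweight complement to `run_visual` (a sketch; the styling choices here are not part of the project's plotting code): mean final-step `match_x` against `lambda_` on a log-scaled x-axis, one line per method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: final-step mean match_x vs. lambda_, one line per method.\n",
    "# Uses only matplotlib, which is imported above; complements run_visual.\n",
    "final_t = all_data[\"t\"].max()\n",
    "final_df = all_data[all_data[\"t\"] == final_t]\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "for method, grp in final_df.groupby(\"method\"):\n",
    "    mean_by_lambda = grp.groupby(variable)[\"match_x\"].mean()\n",
    "    ax.plot(mean_by_lambda.index, mean_by_lambda.values, marker=\"o\", label=method)\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(variable)\n",
    "ax.set_ylabel(\"mean match_x (final step)\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }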
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
