{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "866ac7eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "import math\n",
    "from itertools import product\n",
    "import time\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "from scipy import stats\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import seaborn as sns\n",
    "from datetime import datetime\n",
    "from tqdm import trange\n",
    "from scipy.stats import poisson, nbinom\n",
    "import importlib, temp_estimator\n",
    "importlib.reload(temp_estimator) \n",
    "from temp_estimator import *\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c8eed1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z\n",
    "    \n",
    "def f2(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    # modify to make it can apply to float\n",
    "    return (z + 1e-9)**0.5\n",
    "\n",
    "def f3(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z**2\n",
    "\n",
    "def f4(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z**1.5 - 2 * z\n",
    "\n",
    "def f5(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return torch.cos(z) ** 2\n",
    "\n",
    "def f6(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return torch.sigmoid(z)\n",
    "\n",
    "def f7(z, rate = None): \n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z**2/(rate + 1e-6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f07d1667",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test functions to evaluate, keyed by the name used in the result CSV filenames.\n",
    "f_map = {\n",
    "    \"z\": f1,\n",
    "    \"z^0.5\": f2,\n",
    "    \"z^2\": f3,\n",
    "    \"z^1.5-2z\": f4,\n",
    "    \"cos^2(z)\": f5,\n",
    "    \"sigmoid(z)\": f6,\n",
    "    \"z^2(rate)^-1\": f7,\n",
    "}\n",
    "# Directory containing the df_result_poisson_*.csv files: set the RESULTS_DIR\n",
    "# environment variable, or replace the placeholder.\n",
    "base = Path(os.environ.get(\"RESULTS_DIR\", r\"your dir\"))\n",
    "# NOTE(review): not referenced by any cell in this notebook.\n",
    "selected_n_mc = [1, 2, 5, 10, 20, 50, 100, 200, 500]\n",
    "\n",
    "# LaTeX legend labels for each test function; keys must match f_map.\n",
    "fn_latex_alias = {\n",
    "    \"z\": r\"$z$\",\n",
    "    \"z^0.5\": r\"$\\sqrt{z}$\",\n",
    "    \"z^2\": r\"$z^2$\",\n",
    "    \"z^1.5-2z\": r\"$z^{1.5}-2z$\",\n",
    "    \"cos^2(z)\": r\"$\\cos^2(z)$\",\n",
    "    \"sigmoid(z)\": r\"$\\mathrm{sigmoid}(z)$\",\n",
    "    \"z^2(rate)^-1\": r\"$z^2/\\lambda$\",\n",
    "}\n",
    "assert fn_latex_alias.keys() == f_map.keys(), \"fn_latex_alias and f_map must have the same keys\"\n",
    "\n",
    "# Display names for the values of the 'method' column.\n",
    "method_alias = {\n",
    "    \"exp\": \"EAT_sigmoid\",\n",
    "    \"cubic_exp\": \"EAT_cubic\",\n",
    "    \"GS\": \"Gumbel-Softmax\",\n",
    "    \"score\": \"Score\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935e415a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to compute_mae_and_best_tau (from temp_estimator); presumably the upper bound of the tau search.\n",
    "max_tau = 1.0\n",
    "method_list = [\"GS\", \"exp\", \"cubic_exp\"]\n",
    "dfs = []\n",
    "for fname, f in f_map.items():\n",
    "    matches = sorted(base.glob(f\"df_result_poisson_{fname}_*.csv\"))\n",
    "    if not matches:\n",
    "        print(f\"[skip] no file for {fname}\")\n",
    "        continue\n",
    "\n",
    "    # \"Latest\" = last path in lexicographic order; assumes the filename suffix\n",
    "    # sorts chronologically (e.g. a zero-padded timestamp).\n",
    "    latest = matches[-1]\n",
    "    print(f\"Loading: {latest}\")\n",
    "    # Read each CSV once and reuse it for every method.\n",
    "    df_loaded = pd.read_csv(latest)\n",
    "    for target_method in method_list:\n",
    "        # Each call gets its own copy in case compute_mae_and_best_tau mutates its input.\n",
    "        df_proc, _ = compute_mae_and_best_tau(df_loaded.copy(), target_method, f, max_tau=max_tau)\n",
    "        df_proc[\"fn_name\"] = fname\n",
    "        dfs.append(df_proc)\n",
    "\n",
    "if not dfs:\n",
    "    raise ValueError(\"No files loaded.\")\n",
    "df_all = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de4467c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_best_tau_vs_rate(df, target_method, ax, target_n_mc=5, excluded_xticks=(4.0,)):\n",
    "    \"\"\"Draw best tau against rate on `ax`, one line per test function.\n",
    "\n",
    "    Args:\n",
    "        df: Results frame with columns \"n_monte_carlo\", \"method\", \"fn_name\",\n",
    "            \"rate\", \"best_tau\" and \"mae\" (e.g. df_all).\n",
    "        target_method: Value of the \"method\" column to plot (e.g. \"GS\").\n",
    "        ax: Matplotlib Axes to draw on. Its x axis is set to log scale; the\n",
    "            y scale is left to the caller.\n",
    "        target_n_mc: Only rows with this \"n_monte_carlo\" value are used.\n",
    "        excluded_xticks: Rate values that get no x tick (default: 4.0).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per (fn_name, rate, best_tau) and columns\n",
    "        mae_mean, mae_std, mae_count summarising \"mae\", sorted by fn_name, rate.\n",
    "    \"\"\"\n",
    "    sub = df[(df[\"n_monte_carlo\"] == target_n_mc) & (df[\"method\"] == target_method)]\n",
    "    # NOTE(review): best_tau is a grouping key, so if it is not constant within a\n",
    "    # (fn_name, rate) pair, several points are drawn at the same rate.\n",
    "    agg = (\n",
    "        sub.groupby([\"fn_name\", \"rate\", \"best_tau\"])[\"mae\"]\n",
    "        .agg(mae_mean=\"mean\", mae_std=\"std\", mae_count=\"count\")\n",
    "        .reset_index()\n",
    "        .sort_values([\"fn_name\", \"rate\"])\n",
    "    )\n",
    "    fn_names = sorted(agg[\"fn_name\"].unique())\n",
    "    palette = sns.color_palette(\"tab10\", len(fn_names))\n",
    "    for color, fn_name in zip(palette, fn_names):\n",
    "        d = agg[agg[\"fn_name\"] == fn_name].sort_values(\"rate\")\n",
    "        ax.plot(\n",
    "            d[\"rate\"],\n",
    "            d[\"best_tau\"],\n",
    "            label=fn_latex_alias.get(fn_name, fn_name),\n",
    "            color=color,\n",
    "            marker=\"o\"\n",
    "        )\n",
    "\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.set_xlabel(\"rate\", fontsize=16)\n",
    "    excluded = set(excluded_xticks or ())\n",
    "    xticks = [x for x in sorted(agg[\"rate\"].unique()) if x not in excluded]\n",
    "    ax.set_xticks(xticks)\n",
    "    ax.tick_params(axis='x', labelsize=16)\n",
    "    ax.set_xticklabels([f\"{x:.1f}\" for x in xticks], rotation=45, fontsize=16)\n",
    "    return agg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e879ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per estimator, sharing the (log-scaled) best-tau axis.\n",
    "panel_methods = [\"GS\", \"exp\", \"cubic_exp\"]\n",
    "fig, axs = plt.subplots(\n",
    "    1, len(panel_methods),\n",
    "    figsize=(16, 5),\n",
    "    sharey=True,\n",
    "    layout=\"constrained\"\n",
    ")\n",
    "\n",
    "for ax, target_method in zip(axs, panel_methods):\n",
    "    plot_best_tau_vs_rate(df_all, target_method, ax)\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.tick_params(axis='y', labelsize=16)\n",
    "    ax.set_title(method_alias.get(target_method, target_method), fontsize=18)\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "# The y label and the legend go on the left panel only.\n",
    "axs[0].set_ylabel(\"best $\\\\tau$\", fontsize=16)\n",
    "axs[0].legend(fontsize=16, loc=\"lower left\", ncol=2, columnspacing=0.5)\n",
    "\n",
    "# No fig.tight_layout(): the figure already uses constrained layout, and\n",
    "# tight_layout() would replace that layout engine (matplotlib warns about it).\n",
    "# Save before plt.show(), which may close the figure on non-inline backends.\n",
    "fig.savefig(\"Rate_vs_best_tau.pdf\", transparent=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "csai0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
