{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7802dceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "from itertools import product\n",
    "import time\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import seaborn as sns\n",
    "from datetime import datetime\n",
    "from tqdm import trange\n",
    "from scipy.stats import poisson, nbinom\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd0d6a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z\n",
    "    \n",
    "def f2(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    # modify to make it can apply to float\n",
    "    return (z + 1e-9)**0.5\n",
    "\n",
    "def f3(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z**2\n",
    "\n",
    "def f4(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z**1.5 - 2 * z\n",
    "\n",
    "def f5(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return torch.cos(z) ** 2\n",
    "\n",
    "def f6(z, rate = None):\n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return torch.sigmoid(z)\n",
    "\n",
    "def f7(z, rate = None): \n",
    "    if not torch.is_tensor(z):\n",
    "        z = torch.tensor(z)\n",
    "    return z**2/(rate + 1e-6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f17ba50c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_poisson(\n",
    "    rate,\n",
    "    method: str = \"score\",\n",
    "    n_monte_carlo: int = 1,\n",
    "    upperbound: int = 8,\n",
    "    tau: float = 0.1,\n",
    "):\n",
    "    # rate: Tensor of shape (,)\n",
    "    if method == \"score\":\n",
    "        return torch.poisson(\n",
    "            rate.unsqueeze(0).expand(\n",
    "                n_monte_carlo,\n",
    "            )\n",
    "        )\n",
    "    elif method == \"GS\":\n",
    "        k = torch.arange(upperbound, dtype=torch.float64) # .float()  # [0, 1, ..., upperbound-1]\n",
    "        logit_pi = k * rate.log().unsqueeze(-1) - torch.lgamma(\n",
    "            k + 1\n",
    "        )  # [logit_pi_0, logit_pi_1, ..., logit_pi_(upperbound-1)]\n",
    "        # logit_pi = (\n",
    "        #     k * rate.log().unsqueeze(-1) - rate - torch.lgamma(k + 1)\n",
    "        # )  # [logit_pi_0, logit_pi_1, ..., logit_pi_(upperbound-1)]\n",
    "        z_gs_samples = F.gumbel_softmax(\n",
    "            logit_pi.unsqueeze(0).expand(n_monte_carlo, upperbound), tau=tau, hard=False\n",
    "        )  # one-hot [z_gs_0, z_gs_1, ..., z_gs_(upperbound-1)]\n",
    "        z = z_gs_samples @ k  # z = sum_{k=0}^{upperbound-1} z_gs_k * k\n",
    "        return z\n",
    "    \n",
    "    elif method == \"exp\":\n",
    "        u = torch.rand((n_monte_carlo, upperbound))\n",
    "        z_exp_samples = -(1 - u).log() / rate\n",
    "        z = (torch.sigmoid((1 - torch.cumsum(z_exp_samples, dim=-1)) / tau)).sum(dim=-1) #  (n_monte_carlo,)\n",
    "        return z\n",
    "    \n",
    "    elif method == \"cubic_exp\":\n",
    "        u = torch.rand((n_monte_carlo, upperbound))\n",
    "        z_exp_samples = -(1 - u).log() / rate\n",
    "        times = torch.cumsum(z_exp_samples, dim=-1)\n",
    "        logits = (1 - times) / tau\n",
    "        u_cubic = torch.clamp(0.5 * logits + 0.5, min=0.0, max=1.0)\n",
    "        indicator = 3 * u_cubic.pow(2) - 2 * u_cubic.pow(3)\n",
    "        z = indicator.sum(dim=-1)  #  (n_monte_carlo,)\n",
    "        return z\n",
    "\n",
    "\n",
    "def poisson_log_prob(rate, z):\n",
    "    return z * rate.log() - rate - torch.lgamma(z + 1)\n",
    "\n",
    "\n",
    "def compute_upperbound(rate: float, percentile: float = 1e-2, r: float = None, p: float = None):\n",
    "    assert rate > 0.0, f\"rate must be positive, got: {rate}\"\n",
    "    pois = stats.poisson(rate)\n",
    "    n_exp = pois.ppf(1.0 - percentile)\n",
    "    return max(int(n_exp), 3)\n",
    "\n",
    "\n",
    "def compute_exact_grad(f_fn: callable, rate: float= None, r:float = None, p:float = None):\n",
    "    assert rate is not None and r is None and p is None\n",
    "    upperbound = int(rate.item() + 20)\n",
    "    z = torch.arange(upperbound + 1, dtype=torch.float64)\n",
    "    f_z = f_fn(z, rate)\n",
    "    lam = rate.detach().clone().to(torch.float64)\n",
    "    pmf = torch.exp(poisson_log_prob(lam, z))\n",
    "    exp_term = (f_z * pmf * (z / lam - 1)).sum()\n",
    "    return exp_term.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60082e2e",
   "metadata": {},
   "source": [
    "## Distribution Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4760a54e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "rate = torch.tensor(2.0, dtype=torch.float64)\n",
    "upperbound = 8  # \"exp\"/\"cubic_exp\" counts reach upperbound, \"GS\" reaches upperbound - 1\n",
    "bin_width = 0.1\n",
    "# One bin centred on every integer 0..upperbound. The previous stop of 5 silently dropped all\n",
    "# samples >= 5, which also inflated the normalised densities of the remaining bins.\n",
    "bins = np.arange(-bin_width / 2, upperbound + bin_width, bin_width)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for method, color in zip(\n",
    "    [\"score\", \"GS\", \"exp\", \"cubic_exp\"],\n",
    "    [\"C0\", \"C1\", \"C2\", \"C3\"],\n",
    "):\n",
    "    z = sample_poisson(\n",
    "        rate=rate, method=method, n_monte_carlo=1000, tau=0.1, upperbound=upperbound\n",
    "    )\n",
    "    ax.hist(\n",
    "        z.numpy(),\n",
    "        density=True,\n",
    "        bins=bins,\n",
    "        color=color,\n",
    "        label=method,\n",
    "        fill=False,\n",
    "        histtype=\"step\",\n",
    "        alpha=0.9,\n",
    "    )\n",
    "\n",
    "# Exact pmf as a density: all mass at integer k lands in a single bin of width `bin_width`.\n",
    "k = np.arange(0, upperbound + 1)\n",
    "pmf = poisson(mu=rate.item()).pmf(k)\n",
    "ax.plot(k, pmf / bin_width, \"ko\", label=\"True value\", markersize=2)\n",
    "\n",
    "ax.legend()\n",
    "ax.set(\n",
    "    xlabel=\"z\",\n",
    "    ylabel=\"density\",\n",
    "    title=f\"Poisson Sampling Methods Comparison (rate={rate.item()})\",\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a66070fc",
   "metadata": {},
   "source": [
    "## Adaptive upper bound - Poisson"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb50487",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tail probability handed to compute_upperbound (1e-3 was also tried).\n",
    "percentile = 1e-2\n",
    "\n",
    "# Test functions whose expected-value gradients are estimated; keys are used in file names.\n",
    "f_map = {\n",
    "    \"z\": f1,\n",
    "    \"z^0.5\": f2,\n",
    "    \"z^2\": f3,\n",
    "    \"z^1.5-2z\": f4,\n",
    "    \"cos^2(z)\": f5,\n",
    "    \"sigmoid(z)\": f6,\n",
    "    \"z^2(rate)^-1\": f7,\n",
    "}\n",
    "\n",
    "# Seeds and Monte Carlo sample sizes swept by the experiment loop; rate_list, method_list\n",
    "# and tau_list are set in the next cell.\n",
    "seed_list = np.arange(0, 10)\n",
    "n_monte_carlo_list = [1, 2, 5, 10, 20, 50, 100, 200, 500]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a700609",
   "metadata": {},
   "outputs": [],
   "source": [
    "to_csv = True\n",
    "tau_list = (\n",
    "    # list(np.arange(0.001, 0.00201, 0.0001))\n",
    "    # + list(np.arange(0.002, 0.00501, 0.0005))\n",
    "    # + list(np.arange(0.005, 0.01001, 0.001))\n",
    "    list(np.arange(0.01, 0.02001, 0.001))\n",
    "    + list(np.arange(0.02, 0.0501, 0.001))\n",
    "    + list(np.arange(0.05, 0.1001, 0.005))\n",
    "    + list(np.arange(0.1, 0.2001, 0.005))\n",
    "    + list(np.arange(0.2, 0.5001, 0.01))\n",
    "    + list(np.arange(0.5, 1.0001, 0.01))\n",
    "    # + list(np.arange(0.5, 1.0001, 0.1)) ###\n",
    "    # + list(np.arange(1.0, 2.0001, 0.1))\n",
    "    # + list(np.arange(2.0, 5.0001, 0.5))\n",
    "    # + list(np.arange(5.0, 10.0001, 1.0))\n",
    "    # + list(np.arange(10.0, 20.0001, 1.0))\n",
    "    # + list(np.arange(20.0, 50.0001, 5.0))\n",
    "    # + list(np.arange(50.0, 100.0001, 10.0))\n",
    "    # + list(np.arange(100.0, 500.0001, 50.0))\n",
    "    # + list(np.arange(500.0, 1000.0001, 100.0))\n",
    ")\n",
    "tau_list = sorted({round(x, 6) for x in tau_list})\n",
    "\n",
    "# rate_list = (\n",
    "#     list(np.arange(0.1, 0.2001, 0.01))\n",
    "#     + list(np.arange(0.2, 1.0001, 0.05))\n",
    "#     + list(np.arange(1.0, 2.0001, 0.1))\n",
    "#     + list(np.arange(2.0, 10.0001, 0.5))\n",
    "# )\n",
    "# rate_list = sorted({round(x, 6) for x in rate_list})\n",
    "rate_list = [0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0] # \n",
    "method_list = [\"GS\", \"exp\", \"cubic_exp\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "391b973b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute_upperbound depends only on the rate: evaluate it once per rate, outside the timed\n",
    "# region, so \"time\" measures sampling + backward only (not a scipy quantile lookup).\n",
    "upperbounds = {rate: compute_upperbound(rate, percentile=percentile) for rate in rate_list}\n",
    "\n",
    "for fname, f in f_map.items():\n",
    "    df_result = []\n",
    "\n",
    "    for rate, method, tau, n_monte_carlo, seed in product(\n",
    "        rate_list,\n",
    "        method_list,\n",
    "        tau_list,\n",
    "        n_monte_carlo_list,\n",
    "        seed_list,\n",
    "    ):\n",
    "        # The score-function estimator has no temperature, so run it for a single tau.\n",
    "        if method == \"score\" and (tau != tau_list[0]):\n",
    "            continue\n",
    "        torch.manual_seed(seed)\n",
    "        rate_tensor = torch.tensor(rate, requires_grad=True)\n",
    "        start_time = time.perf_counter()\n",
    "        z_samples = sample_poisson(\n",
    "            rate_tensor,\n",
    "            method=method,\n",
    "            n_monte_carlo=n_monte_carlo,\n",
    "            upperbound=upperbounds[rate],\n",
    "            tau=tau,\n",
    "        )\n",
    "        # `rate` is a plain float here, so f's explicit rate dependence (f7) is not differentiated.\n",
    "        f_z = f(z_samples, rate=rate)\n",
    "        if method == \"score\":\n",
    "            # Surrogate loss whose gradient is the score-function estimator.\n",
    "            loss = (f_z.detach() * poisson_log_prob(rate_tensor, z_samples)).mean()\n",
    "        else:\n",
    "            # Pathwise estimator: differentiate through the relaxed samples.\n",
    "            loss = f_z.mean()\n",
    "        loss.backward()\n",
    "        elapsed = time.perf_counter() - start_time\n",
    "        df_result.append(\n",
    "            {\n",
    "                \"rate\": rate,\n",
    "                \"method\": method,\n",
    "                \"$\\\\tau$\": tau,\n",
    "                \"n_monte_carlo\": n_monte_carlo,\n",
    "                \"seed\": seed,\n",
    "                \"time\": elapsed,\n",
    "                \"grad_estimate\": rate_tensor.grad.item(),\n",
    "            }\n",
    "        )\n",
    "    # Overwritten for every f: after the loop this holds the results of the LAST f_map entry\n",
    "    # (its key is left in `fname`); the plotting cells below read it.\n",
    "    df_result_poisson = pd.DataFrame(df_result)\n",
    "\n",
    "    if to_csv:\n",
    "        timestamp = time.strftime(\"%Y%m%d_%H%M%S\")\n",
    "        base = Path(r\"your dir\")  # TODO: point this at the real output directory\n",
    "        base.mkdir(parents=True, exist_ok=True)  # avoid losing a finished sweep to a missing dir\n",
    "        output_path = base / f\"df_result_poisson_{fname}_{timestamp}.csv\"\n",
    "        df_result_poisson.to_csv(output_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc76f770",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the frame left over from the loop above. It holds the results of the LAST f_map entry,\n",
    "# whose key is still bound to `fname`; a hard-coded \"sigmoid\" label mislabelled the file.\n",
    "timestamp = time.strftime(\"%Y%m%d_%H%M%S\")\n",
    "base = Path(r\"your dir\")  # TODO: point this at the real output directory\n",
    "base.mkdir(parents=True, exist_ok=True)\n",
    "output_path = base / f\"df_result_poisson_{fname}_{timestamp}.csv\"\n",
    "df_result_poisson.to_csv(output_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "015070a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "base = Path(r\"your dir\")\n",
    "# Newest file by the trailing %Y%m%d_%H%M%S timestamp (last 15 chars of the stem). Sorting\n",
    "# whole file names orders by function name first, so [-1] was not necessarily the latest run.\n",
    "candidates = list(base.glob(\"df_result_poisson_*.csv\"))\n",
    "if not candidates:\n",
    "    raise FileNotFoundError(f\"no df_result_poisson_*.csv files found in {base}\")\n",
    "latest = max(candidates, key=lambda path: path.stem[-15:])\n",
    "print(f\"Loading: {latest}\")\n",
    "df_loaded = pd.read_csv(latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcfde867",
   "metadata": {},
   "source": [
    "### Method Result Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "327059d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test function whose exact gradient is drawn in the next cell. It must be the function that\n",
    "# produced `df_result_poisson` (the experiment loop leaves the results of the LAST f_map entry).\n",
    "f = f7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba933be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temperatures shown as rows; subsample (e.g. tau_list[::10]) for a readable grid.\n",
    "tau_plot_list = tau_list\n",
    "hue_order = [\"GS\", \"exp\", \"cubic_exp\"]  # add \"score\" when it is in method_list\n",
    "\n",
    "# The exact gradient depends only on f and the rate: compute it once per column instead of\n",
    "# once per subplot.\n",
    "y_true_by_rate = {\n",
    "    rate: compute_exact_grad(f, torch.tensor(rate, dtype=torch.float64)) for rate in rate_list\n",
    "}\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    len(tau_plot_list),\n",
    "    len(rate_list),\n",
    "    figsize=(16, 16),\n",
    "    layout=\"constrained\",\n",
    "    sharex=True,\n",
    "    sharey=\"col\",\n",
    "    squeeze=False,  # keep axs 2-D even with a single tau or rate\n",
    ")\n",
    "for i, tau in enumerate(tau_plot_list):\n",
    "    for j, rate in enumerate(rate_list):\n",
    "        y_true = y_true_by_rate[rate]\n",
    "        ax = axs[i, j]\n",
    "        df_plot = df_result_poisson[\n",
    "            ((df_result_poisson[\"$\\\\tau$\"] == tau) | (df_result_poisson[\"method\"] == \"score\"))\n",
    "            & (df_result_poisson[\"rate\"] == rate)\n",
    "        ]\n",
    "        sns.lineplot(\n",
    "            data=df_plot,\n",
    "            x=\"n_monte_carlo\",\n",
    "            y=\"grad_estimate\",\n",
    "            hue=\"method\",\n",
    "            hue_order=hue_order,\n",
    "            ax=ax,\n",
    "        )\n",
    "        ax.axhline(\n",
    "            y=y_true, color=\"k\", linestyle=\"--\", label=f\"True Grad\\n= {y_true:.3f}\"\n",
    "        )\n",
    "        ax.set_title(f\"rate={rate}, $\\\\tau$={tau}\")\n",
    "        ax.set_xlabel(\"Number of Monte Carlo Samples\")\n",
    "        ax.set_ylabel(\"Gradient Estimate\")\n",
    "        # Rebuild the legend so it also lists the true-gradient line. get_legend() returns None\n",
    "        # when seaborn drew no legend (presumably e.g. for an empty df_plot), so guard first.\n",
    "        seaborn_legend = ax.get_legend()\n",
    "        if seaborn_legend is not None:\n",
    "            seaborn_legend.remove()\n",
    "        ax.set_xscale(\"log\")\n",
    "        ax.set_yscale(\"linear\")\n",
    "        ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d02e12",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "csai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
