{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "from pathlib import Path\n",
    "\n",
    "import einops\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as tkr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from scipy.stats import wasserstein_distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Based on science plots retro:\n",
    "# https://github.com/garrettj403/SciencePlots/blob/master/styles/color/retro.mplstyle\n",
    "# NOTE(review): _COLORS is currently only referenced by the disabled \"axes.prop_cycle\" entry below.\n",
    "_COLORS = sns.blend_palette(\n",
    "    [\n",
    "        \"#4165c0\",\n",
    "        \"#e770a2\",\n",
    "        \"#5ac3be\",\n",
    "        \"#696969\",\n",
    "        \"#f79a1e\",\n",
    "        \"#ba7dcd\",\n",
    "    ]\n",
    ")\n",
    "\n",
    "_GRID_LINE_WIDTH = 1.0\n",
    "_STROKE_WIDTH = 2.0\n",
    "# Grid colour from Open Colors (previously \"#b0b0b0\").\n",
    "_GRID_COLOR = \"#adb5bd\"\n",
    "# Axes, ticks, labels and text all use plain black; the Open Colors greys\n",
    "# \"#868e96\" (ticks, labels) and \"#343a40\" (text) were used previously.\n",
    "_AXES_COLOR = \"#000000\"\n",
    "_TICK_COLOR = _AXES_COLOR\n",
    "_LABEL_COLOR = _AXES_COLOR\n",
    "_TEXT_COLOR = _AXES_COLOR\n",
    "\n",
    "_TICK_PAD = 5\n",
    "\n",
    "# Customized Dufte style from mplx; an rcParams dict for `plt.style.context(theme)`.\n",
    "theme = {\n",
    "    # Color cycle (disabled)\n",
    "    # \"axes.prop_cycle\": cycler.cycler(color=_COLORS.as_hex()),  # pyright: ignore\n",
    "    \"font.size\": 20,\n",
    "    \"font.family\": \"sans-serif\",\n",
    "    # \"font.sans-serif\": \"Helvetica\",\n",
    "    \"text.color\": _TEXT_COLOR,\n",
    "    \"axes.labelcolor\": _LABEL_COLOR,\n",
    "    \"axes.labelpad\": 18,\n",
    "    # Full frame: all four spines are drawn.\n",
    "    \"axes.spines.left\": True,\n",
    "    \"axes.spines.bottom\": True,\n",
    "    \"axes.spines.top\": True,\n",
    "    \"axes.spines.right\": True,\n",
    "    \"ytick.minor.left\": True,\n",
    "    # Spine (frame) colour and width.\n",
    "    \"axes.edgecolor\": _AXES_COLOR,\n",
    "    \"axes.linewidth\": _GRID_LINE_WIDTH,\n",
    "    # True puts the grid below all artists; the default, \"line\", puts it below\n",
    "    # lines but above patches (bars).\n",
    "    \"axes.axisbelow\": True,\n",
    "    # Ticks\n",
    "    \"ytick.right\": False,\n",
    "    \"ytick.color\": _TICK_COLOR,\n",
    "    \"ytick.major.width\": _GRID_LINE_WIDTH,\n",
    "    \"ytick.major.pad\": _TICK_PAD,\n",
    "    \"ytick.labelsize\": 16,\n",
    "    \"xtick.minor.top\": False,\n",
    "    \"xtick.minor.bottom\": False,\n",
    "    \"xtick.color\": _TICK_COLOR,\n",
    "    \"xtick.major.width\": _GRID_LINE_WIDTH,\n",
    "    \"xtick.major.pad\": _TICK_PAD,\n",
    "    \"xtick.labelsize\": 16,\n",
    "    # Horizontal grid lines only.\n",
    "    \"axes.grid\": True,\n",
    "    \"axes.grid.axis\": \"y\",\n",
    "    \"axes.labelsize\": 18,\n",
    "    \"grid.color\": _GRID_COLOR,\n",
    "    \"figure.constrained_layout.h_pad\": 0.1,\n",
    "    \"figure.constrained_layout.w_pad\": 0.0,\n",
    "    \"lines.linewidth\": _STROKE_WIDTH,\n",
    "    # Choose the line width such that it's very subtle, but still serves as a guide.\n",
    "    \"grid.linewidth\": _GRID_LINE_WIDTH * 0.5,\n",
    "    \"grid.alpha\": 0.5,\n",
    "    \"axes.xmargin\": 0.1,\n",
    "    \"axes.ymargin\": 0.1,\n",
    "    \"axes.titlepad\": 7.5,\n",
    "    \"axes.titlesize\": 20,\n",
    "    # NOTE: usetex renders all text through LaTeX, which must be installed.\n",
    "    \"text.usetex\": True,\n",
    "    # \"pgf.texsystem\": \"xelatex\",\n",
    "    \"pgf.rcfonts\": True,\n",
    "    # \"text.latex.preamble\": r\"\\usepackage{fontspec}\\setmainfont{Arial}\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def brownian_motion(rng: jax.Array, ts: jax.Array) -> jax.Array:\n",
    "    \"\"\"Draw one Brownian-motion sample path on the time grid `ts`.\n",
    "\n",
    "    The path values at times t_i and t_j have covariance min(t_i, t_j), so the\n",
    "    whole path is sampled as one zero-mean multivariate normal over `ts`.\n",
    "    `ts` is assumed to be 1-D. `method='svd'` tolerates a singular covariance,\n",
    "    which occurs e.g. when `ts` contains 0 (that row and column are all zero).\n",
    "    \"\"\"\n",
    "    # cov[i, j] = min(ts[i], ts[j]), built by broadcasting a column against a row.\n",
    "    cov = jnp.minimum(ts[:, None], ts[None, :])\n",
    "    mean = jnp.zeros_like(ts)\n",
    "    return jax.random.multivariate_normal(rng, mean, cov, method='svd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation grid: N equally spaced times on [0, data_h].\n",
    "rng = jax.random.PRNGKey(0)\n",
    "N = 2000\n",
    "data_h = 0.0001\n",
    "ts = jnp.linspace(0, data_h, N)\n",
    "dt = ts[1] - ts[0]\n",
    "\n",
    "# Deterministic linear drift with constant velocity, on the same time grid.\n",
    "vel = 10.0\n",
    "drift = vel * ts\n",
    "\n",
    "# 5000 independent Brownian paths (one PRNG key each), shifted by the drift.\n",
    "path_keys = jax.random.split(rng, 5000)\n",
    "sample_paths = jax.vmap(brownian_motion, in_axes=(0, None))(path_keys, ts) + drift\n",
    "print(sample_paths.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_h = len(ts)\n",
    "increment = 150\n",
    "# Grid indices of the time-step sizes h that are evaluated.\n",
    "h_idx = jnp.arange(2 * increment, max_h - 3 * increment, increment)\n",
    "H = 10\n",
    "# Head term: Riemann sum (step dt) of each path over grid points 0..h inclusive.\n",
    "cumulative_rewards = jnp.cumsum(sample_paths, axis=-1)\n",
    "head_returns = jnp.take_along_axis(cumulative_rewards, h_idx[None, :], axis=-1) * dt\n",
    "# Tail term: the path value at index h multiplied by the remaining time H - h * dt.\n",
    "tail_returns = jnp.take_along_axis(sample_paths, h_idx[None, :], axis=-1) * (H - h_idx * dt)\n",
    "# One return sample per path (rows) and per h (columns).\n",
    "return_distributions = head_returns + tail_returns\n",
    "\n",
    "# Time-step size h for each column of return_distributions.\n",
    "h_values = h_idx * dt\n",
    "return_means = jnp.mean(return_distributions, axis=0)\n",
    "# DAU+DSUP(1/2): rescale the mean by 1/h (advantage-updating style) and only the\n",
    "# centred fluctuations by 1/sqrt(h). Dividing the uncentred returns by sqrt(h)\n",
    "# instead would add the mean a second time (scaled by 1/sqrt(h)).\n",
    "processed_dists = [\n",
    "    (r\"$\\zeta^\\pi_h$\", return_distributions),\n",
    "    (\"DSUP(1)\", return_distributions / h_values),\n",
    "    (\"DSUP(1/2)\", return_distributions / jnp.power(h_values, 0.5)),\n",
    "    (\n",
    "        \"DAU+DSUP(1/2)\",\n",
    "        return_means / h_values + (return_distributions - return_means) / jnp.power(h_values, 0.5),\n",
    "    ),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One NumPy array per processed distribution, keyed by its label.\n",
    "data = dict((name, np.array(dist)) for (name, dist) in processed_dists)\n",
    "np.savez_compressed('mc_data.npz', **data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One colour per h row in the ridge plots.\n",
    "pal = sns.cubehelix_palette(len(h_idx), rot=-0.25, light=0.7)\n",
    "# Figures keyed by distribution tag; filled by the plotting loop of this cell.\n",
    "ridgeplots = {}\n",
    "print(len(h_idx))\n",
    "\n",
    "# Fully transparent axes background (RGBA alpha 0) so overlapping ridge rows stay visible.\n",
    "sns.set_theme('paper', style=\"white\", rc={\"axes.facecolor\": (0, 0, 0, 0)})\n",
    "\n",
    "def label(x, color, label):\n",
    "    \"\"\"FacetGrid.map callback: write this facet's hue label next to its ridge.\n",
    "\n",
    "    `x` (the mapped data column) is ignored; `color` and `label` are the keyword\n",
    "    arguments seaborn passes for the current hue level. The text is anchored in\n",
    "    axes coordinates, left of the axes (x=-0.2, y=0.2).\n",
    "    \"\"\"\n",
    "    target = plt.gca()\n",
    "    target.text(\n",
    "        -.2, .2, label,\n",
    "        color=color,\n",
    "        ha=\"left\",\n",
    "        va=\"center\",\n",
    "        transform=target.transAxes,\n",
    "        fontsize=12,\n",
    "    )\n",
    "\n",
    "# Row label for each h: the step frequency 1 / h in kHz, rounded (not truncated)\n",
    "# to the nearest integer.\n",
    "h_idx_str = [f\"${round(float(1 / (h * dt * 1000)))}$kHz\" for h in h_idx]\n",
    "# FacetGrid groups rows by label, so two h values with the same label would be merged.\n",
    "assert len(set(h_idx_str)) == len(h_idx_str), f\"duplicate row labels: {h_idx_str}\"\n",
    "\n",
    "with plt.style.context(theme):\n",
    "    for (i, (tag, ret_dist_scaled)) in enumerate(processed_dists):\n",
    "        print(f\"{i + 1:>4}/{len(processed_dists)}\")\n",
    "        n_samples, _ = ret_dist_scaled.shape\n",
    "        # Row-major flatten: all h columns of sample 0, then of sample 1, ...\n",
    "        df_data = np.asarray(ret_dist_scaled).reshape(-1)\n",
    "        # ... which matches repeating the column labels once per sample.\n",
    "        df_labels = np.tile(h_idx_str, n_samples)\n",
    "        df = pd.DataFrame(dict(data=df_data, h=df_labels))\n",
    "        grid = sns.FacetGrid(df, row='h', hue='h', palette=pal, aspect=15, height=0.2)\n",
    "        for ax in grid.axes.flatten():\n",
    "            ax.xaxis.set_major_locator(tkr.MaxNLocator(3))\n",
    "        # Filled density per row, then a thin white outline drawn over it.\n",
    "        grid.map(sns.kdeplot, \"data\", clip_on=False, fill=True, alpha=1, linewidth=1.5)\n",
    "        grid.map(sns.kdeplot, \"data\", clip_on=False, color='w', lw=1)\n",
    "\n",
    "        # Only the third figure (i == 2) carries the per-row frequency labels.\n",
    "        if i == 2:\n",
    "            grid.map(label, \"data\")\n",
    "        # Negative hspace makes neighbouring rows overlap (ridge-plot look).\n",
    "        grid.figure.subplots_adjust(hspace=-0.5)\n",
    "        grid.set_titles(\"\")\n",
    "        grid.set_xlabels(\"\")\n",
    "        grid.set(yticks=[], ylabel=\"\")\n",
    "        grid.despine(bottom=True, left=True)\n",
    "        ridgeplots[tag] = grid.figure\n",
    "\n",
    "# Display the first ridge plot.\n",
    "next(iter(ridgeplots.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory first so the cell also works on a fresh checkout.\n",
    "results_dir = Path(\"results\")\n",
    "results_dir.mkdir(parents=True, exist_ok=True)\n",
    "for (i, key) in enumerate(ridgeplots.keys()):\n",
    "    ridgeplots[key].savefig(results_dir / f\"centerlegend-ridgeplot-{i}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squeeze=False keeps `axs` 2-D even when there is a single distribution.\n",
    "fig, axs = plt.subplots(1, len(processed_dists), figsize=(len(processed_dists) * 5, 5), squeeze=False)\n",
    "# One colour per h column (reversed \"crest\" palette).\n",
    "palette = sns.color_palette(\"crest\", len(h_idx))[::-1]\n",
    "for (i, (tag, ret_dist_scaled)) in enumerate(processed_dists):\n",
    "    ax = axs[0, i]\n",
    "    ax.set_title(tag)\n",
    "    # Overlay one KDE per h column of the plotted array; iterating in reverse draws\n",
    "    # column 0 last, i.e. on top.\n",
    "    for h in reversed(range(ret_dist_scaled.shape[-1])):\n",
    "        sns.kdeplot(data=ret_dist_scaled[:, h], ax=ax, color=palette[h])\n",
    "\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot against the actual step size h (not the column index), matching the \"h\" x-label.\n",
    "h_grid = np.asarray(h_idx * dt)\n",
    "fig, axs = plt.subplots(2, len(processed_dists), figsize=(len(processed_dists) * 5, 10), sharex=True, squeeze=False)\n",
    "for (i, (tag, ret_dist_scaled)) in enumerate(processed_dists):\n",
    "    samples = np.asarray(ret_dist_scaled, dtype=np.float64)\n",
    "    axs[0, i].set_title(tag)\n",
    "    # |mean| of each h column.\n",
    "    means = np.abs(samples.mean(axis=0))\n",
    "    # W1 distance to the point mass at 0 equals E|eta|, i.e. what\n",
    "    # scipy.stats.wasserstein_distance(eta, [0]) returns, computed for all columns at once.\n",
    "    w1_distances = np.abs(samples).mean(axis=0)\n",
    "    axs[0, i].plot(h_grid, means)\n",
    "    axs[1, i].plot(h_grid, w1_distances)\n",
    "    axs[1, i].set_xlabel(\"h\")\n",
    "axs[0, 0].set_ylabel(\"Mean action gap\")\n",
    "axs[1, 0].set_ylabel(\"W1 action gap\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
