{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scienceplots\n",
    "plt.style.use('science')\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from volatility_smoothing.utils.test import spread_error\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['savefig.dpi'] = 300\n",
    "plt.rcParams['savefig.bbox'] = 'tight'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resources = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"../eval/store/9457504 (finetune)\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_files = [f\"{data_path}/{filename}\" for filename in os.listdir(data_path)]\n",
    "data = torch.load(data_files[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_files = [file for file in data_files if int(file.split('_')[-1][:4]) <= 2020]\n",
    "test_files = [file for file in data_files if int(file.split('_')[-1][:4]) > 2020]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_range = (0.1, 1.)\n",
    "y_range = (-1.5, .5)\n",
    "\n",
    "def spatial_average(x, y, feature, x_range=x_range, y_range=y_range,\n",
    "                    x_bins = 20, y_bins=20):\n",
    "\n",
    "    hist, _, _ = np.histogram2d(x.squeeze(), y.squeeze(), bins=(x_bins, y_bins), range=(x_range, y_range), weights=feature.squeeze())\n",
    "    counts, _, _ = np.histogram2d(x.squeeze(), y.squeeze(), bins=(x_bins, y_bins), range=(x_range, y_range))\n",
    "\n",
    "    return hist / counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_heatmap(files, feature_fun):\n",
    "    avgs = []\n",
    "    for file in files:\n",
    "        data = torch.load(file)\n",
    "\n",
    "        x, y, feature = feature_fun(data)\n",
    "        avg = spatial_average(x, y, feature)\n",
    "        avgs.append(avg)\n",
    "\n",
    "    out = np.nanmean(np.stack(avgs), axis=0)\n",
    "    return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.colors as colors\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "norm = colors.TwoSlopeNorm(vmin=-0.1, vcenter=0.01, vmax=0.1)\n",
    "\n",
    "\n",
    "aspect = 2.\n",
    "extent = (*y_range, *x_range)\n",
    "fraction = 0.045\n",
    "pad = 0.05\n",
    "\n",
    "\n",
    "def plot_error(error, ax=None, fig=None, vmin=None, vmax=None):\n",
    "\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots()\n",
    "\n",
    "    im = ax.imshow(error, origin='lower', cmap='PuRd', extent=extent, aspect=aspect, vmin=vmin, vmax=vmax) #norm=norm)\n",
    "    ax.set_xlabel(r'$z$')\n",
    "    ax.set_ylabel(r'$\\rho$')\n",
    "    ax.set_xticks([-1.5, -1., -0.5, 0., 0.5])\n",
    "\n",
    "    return im\n",
    "\n",
    "\n",
    "def plot_arb_constraint(arb, ax=None, fig=None, eps=0.001):\n",
    "    \n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots()\n",
    "\n",
    "    vmax = arb.max()\n",
    "    \n",
    "    norm = colors.TwoSlopeNorm(vmin=0, vcenter=eps, vmax=vmax)\n",
    "    cax = ax.imshow(arb, origin='lower', cmap='RdYlBu', norm=norm, aspect=aspect, extent=extent)\n",
    "    cbar = fig.colorbar(cax, ax=ax, fraction=fraction, pad=pad)\n",
    "    cbar.set_ticks([0., eps], labels=[\"0\", r\"$\\varepsilon$\"]) #, f\"{vmax:.1f}\"])\n",
    "    cbar.ax.minorticks_off()\n",
    "\n",
    "    ax.set_xlabel(r'$z$')\n",
    "    ax.set_ylabel(r'$\\rho$')\n",
    "    ax.set_xticks([-1.5, -1., -0.5, 0., 0.5])\n",
    "\n",
    "    return cax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mape_fun(data):\n",
    "    x = data['r'].squeeze()\n",
    "    y = data['z'].squeeze()\n",
    "    iv_predict = data['iv_predict'].squeeze()\n",
    "    iv_target = data['implied_volatility'].squeeze()\n",
    "    return x, y, np.abs((iv_predict - iv_target) / iv_target)\n",
    "\n",
    "def spread_fun(data): \n",
    "    x = data['r'].squeeze()\n",
    "    y = data['z'].squeeze()\n",
    "    return x, y, spread_error(data['iv_predict'], data)\n",
    "\n",
    "\n",
    "abs_error_train = create_heatmap(train_files, mape_fun)\n",
    "abs_error_test = create_heatmap(test_files[:1000], mape_fun)\n",
    "spread_error_train = create_heatmap(train_files, spread_fun)\n",
    "spread_error_test = create_heatmap(test_files[:1000], spread_fun)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Arbitrage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def but_fun(data, eps=0.):\n",
    "    x = data['grid']['r'].flatten()\n",
    "    y = data['grid']['z'].flatten()\n",
    "\n",
    "    but_error = F.relu(-data['butterfly_error'].flatten() - eps)\n",
    "    but_error = data['butterfly_error'].flatten()\n",
    "    return x, y, but_error\n",
    "\n",
    "def cal_fun(data):\n",
    "    x = data['grid']['r'][1:].flatten()\n",
    "    y = data['grid']['z'][1:].flatten()\n",
    "    return x, y, data['calendar_error'].flatten()\n",
    "\n",
    "\n",
    "but_error_train = create_heatmap(train_files[:1000], but_fun)\n",
    "but_error_test = create_heatmap(test_files[:1000], but_fun)\n",
    "cal_error_train = create_heatmap(train_files[:1000], cal_fun)\n",
    "cal_error_test = create_heatmap(test_files[:1000], cal_fun)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Individual Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figsize = (1.8, 1.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "props = dict(boxstyle='round', facecolor='white', alpha=1.)\n",
    "\n",
    "# place a text box in upper left in axes coords\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 4, figsize=(10, 2), sharey=True)\n",
    "\n",
    "im = plot_error(abs_error_test, ax=(ax := axes[0]), fig=fig, vmax=0.05)\n",
    "fig.colorbar(im, fraction=fraction, pad=pad)\n",
    "ax.set_title(r\"$\\delta_\\text{abs}$\")\n",
    "\n",
    "im = plot_error(spread_error_test, ax=(ax := axes[1]), fig=fig, vmax=5.)\n",
    "fig.colorbar(im, fraction=fraction, pad=pad)\n",
    "ax.set_ylabel(None)\n",
    "ax.set_title(r'$\\delta_\\text{spr}$')\n",
    "\n",
    "\n",
    "im = plot_arb_constraint(but_error_test, ax=(ax := axes[2]), fig=fig)\n",
    "axes[2].set_ylabel(None)\n",
    "ax.set_title(r\"$\\text{But}(\\cdot, \\hat{v}, \\partial_k \\hat{v}, \\partial_k^2 \\hat{v})$\")\n",
    "#ax.text(0.05, 0.95, r\"$\\text{But}(\\hat{v})$\", transform=ax.transAxes, fontsize=12,\n",
    "#            verticalalignment='top', bbox=props)\n",
    "\n",
    "im = plot_arb_constraint(cal_error_test, ax=(ax := axes[3]), fig=fig)\n",
    "ax.set_ylabel(None)\n",
    "ax.set_title(r\"$\\partial_\\tau [\\hat{v}\\sqrt{\\tau}]$\")\n",
    "#ax.text(0.05, 0.95, r\"$\\partial_\\tau [\\hat{v}\\sqrt{\\tau}]$\", transform=ax.transAxes, fontsize=12,\n",
    "#            verticalalignment='top', bbox=props)\n",
    "\n",
    "fig.savefig(f\"{resources}/spatial_metrics.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "im = plot_error(abs_error_test, ax=ax, fig=fig)\n",
    "fig.colorbar(im, fraction=fraction, pad=pad)\n",
    "fig.savefig(f\"{resources}/absolute_error_domain.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "im = plot_error(spread_error_test, ax=ax, fig=fig)\n",
    "fig.colorbar(im, fraction=fraction, pad=pad)\n",
    "fig.savefig(f\"{resources}/spread_error_domain.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "im = plot_arb_constraint(but_error_test, ax=ax, fig=fig)\n",
    "fig.savefig(f\"{resources}/strike_arbitrage_domain.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "im = plot_arb_constraint(cal_error_test, ax=ax, fig=fig)\n",
    "fig.savefig(f\"{resources}/cal_arbitrage_domain.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_range = (0.01, 1)\n",
    "y_range = (-1.5, 0.5)\n",
    "\n",
    "def create_heatmap_plot(a_hist, b_hist, extent = (*y_range, *x_range)):\n",
    "\n",
    "    vmin = min(a_hist.min(), b_hist.min())\n",
    "    vmax = max(a_hist.max(), b_hist.max())\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(4, 2), sharey=True)\n",
    "\n",
    "    im = plot_error(a_hist, ax=axes[0], fig=fig, vmin=vmin, vmax=vmax)\n",
    "    im = plot_error(b_hist, ax=axes[1], fig=fig, vmin=vmin, vmax=vmax)\n",
    "\n",
    "\n",
    "    #for lh in legend.legendHandles:\n",
    "    #    lh.set_alpha(1.0)\n",
    "\n",
    "    # Shared colorbar\n",
    "    fig.subplots_adjust(right=0.9)\n",
    "    cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])\n",
    "    fig.colorbar(im, cax=cbar_ax)\n",
    "\n",
    "    return fig, axes\n",
    "\n",
    "#fig.savefig(f\"{resources}/trade_volume.png\", dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = create_heatmap_plot(spread_error_train, spread_error_test)\n",
    "#fig.savefig(f\"{resources}/spread_error_spatial_raw.png\", dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_heatmap_plot(abs_error_train, abs_error_test)\n",
    "#fig.savefig(f\"{resources}/abs_error_spatial_raw.png\", dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Arbitrage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def but_fun(data, eps=0.):\n",
    "    x = data['grid']['r'].flatten()\n",
    "    y = data['grid']['z'].flatten()\n",
    "\n",
    "    but_error = F.relu(-data['butterfly_error'].flatten() - eps)\n",
    "    but_error = data['butterfly_error'].flatten()\n",
    "    return x, y, but_error\n",
    "\n",
    "def cal_fun(data):\n",
    "    x = data['grid']['r'][1:].flatten()\n",
    "    y = data['grid']['z'][1:].flatten()\n",
    "    return x, y, data['calendar_error'].flatten()\n",
    "\n",
    "\n",
    "but_error_train = create_heatmap(train_files[:100], but_fun)\n",
    "but_error_test = create_heatmap(test_files[:100], but_fun)\n",
    "cal_error_train = create_heatmap(train_files[:100], cal_fun)\n",
    "cal_error_test = create_heatmap(test_files[:100], cal_fun)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(5, 2), sharey=True)\n",
    "\n",
    "plot_arb_constraint(but_error_test, ax=axes[1], fig=fig)\n",
    "plot_arb_constraint(cal_error_test, ax=axes[0], fig=fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_arb_constraint(cal_error_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "norm = colors.TwoSlopeNorm(vmin=0, vcenter=0.001, vmax=but_error_test.max())\n",
    "plt.imshow(but_error_test, origin='lower', cmap=plt.cm.coolwarm_r, norm=norm)\n",
    "cbar = plt.colorbar()\n",
    "cbar.set_ticks([0, 0.01, 1], labels=[r\"$0$\", r\"$\\varepsilon$\", \"$1$\"])\n",
    "#cbar.set_ticks(locator)\n",
    "#cbar.update_ticks()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(cal_error_test, origin='lower', cmap=plt.cm.coolwarm_r, norm=norm)\n",
    "plt.colorbar()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "op-ds-cqZ6S183-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
