{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scienceplots\n",
    "plt.style.use('science')\n",
    "import torch\n",
    "\n",
    "from op_ds.gno.gno import GNOLayer, GNO\n",
    "from op_ds.gno.kernel import NonlinearKernelTransformWithSkip\n",
    "from op_ds.utils.fnn import FNN\n",
    "from volatility_smoothing.utils.gno.train import Trainer\n",
    "from volatility_smoothing.utils.options_data import WRDSOptionsDataset\n",
    "from volatility_smoothing.utils.gno.dataset import GNOOptionsDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['savefig.dpi'] = 300\n",
    "plt.rcParams['savefig.bbox'] = 'tight'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "resources = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "spx_dir = \"../data/wrds/spx2018\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spx_dataset = WRDSOptionsDataset(spx_dir, return_as='pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from volatility_smoothing.utils.gno.dataset import GNOOptionsDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "gno_dataset = GNOOptionsDataset(spx_dataset, subsample=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and Optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = 1\n",
    "out_channels = 1\n",
    "channels = (in_channels, 16, 16, 16, out_channels)\n",
    "spatial_dim = 2\n",
    "gno_channels = 16\n",
    "hidden_channels = 64\n",
    "\n",
    "gno_layers = []\n",
    "\n",
    "for i in range(m := (len(channels) - 1)):\n",
    "    lifting = FNN.from_config((channels[i], hidden_channels, gno_channels), hidden_activation='gelu', batch_norm=False)\n",
    "    projection = None if i < m - 1 else FNN.from_config((gno_channels, hidden_channels, channels[i+1]), hidden_activation='gelu', batch_norm=False)\n",
    "    transform = NonlinearKernelTransformWithSkip(in_channels=gno_channels, out_channels=gno_channels, skip_channels=in_channels, spatial_dim=spatial_dim, hidden_channels=(hidden_channels, hidden_channels), hidden_activation='gelu', batch_norm=False)\n",
    "\n",
    "    if i == 0:\n",
    "        local_linear = False\n",
    "    else:\n",
    "        local_linear = True\n",
    "        \n",
    "    activation = torch.nn.GELU() if i < m - 1 else torch.nn.Softplus(beta=0.5)\n",
    "        \n",
    "    gno_layer = GNOLayer(gno_channels, transform=transform, local_linear=local_linear, local_bias=True,\n",
    "                         activation=activation, lifting=lifting, projection=projection)\n",
    "    gno_layers.append(gno_layer)\n",
    "    \n",
    "gno = GNO(*gno_layers, in_channels=in_channels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(gno.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_checkpoint(model, optimizer, path):\n",
    "    checkpoint = torch.load(path, map_location=device)\n",
    "    model.load_state_dict(checkpoint['model'])\n",
    "    optimizer.load_state_dict(checkpoint['optimizer'])\n",
    "    # logger.info(f\"Loaded checkpoint from {path}\")\n",
    "    return model, optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../train/store/9448705/checkpoints/checkpoint_final.pt\"\n",
    "gno, optimizer = load_checkpoint(gno, optimizer, path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "step_r = 0.05\n",
    "step_z = 0.01\n",
    "subsample_size = 50\n",
    "radius = 0.3\n",
    "trainer = Trainer(step_r=step_r, step_z=step_z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from volatility_smoothing.utils.slice_data import slice_data\n",
    "from volatility_smoothing.utils.gno.edge_index import generate_edge_index\n",
    "\n",
    "\n",
    "def generate_split_mask(n: int, q: float = 0.5):\n",
    "    num_elements_to_drop = int(q * n)\n",
    "    drop_indices = torch.randperm(n)[:num_elements_to_drop]\n",
    "    mask = torch.ones(n, dtype=torch.bool)\n",
    "    mask[drop_indices] = False\n",
    "    return mask\n",
    "\n",
    "\n",
    "def split_slice(slice: torch.Tensor, q: float=0.5, extrapolate: bool = False):\n",
    "    # slice must be [n_channels, n_samples]\n",
    "    z = slice[1]\n",
    "    mask = generate_split_mask(z.shape[0], q=q)\n",
    "    if extrapolate:\n",
    "        moneyness_mask = (-1.3 <= z) & (z <= 0.3)\n",
    "        mask = moneyness_mask & mask\n",
    "\n",
    "    return slice[..., mask], slice[..., ~mask]\n",
    "\n",
    "\n",
    "def compute_test_error_2(data, **kwargs):\n",
    "    r_uniques, slices = slice_data(data.r, data.r, data.z, data.implied_volatility)\n",
    "\n",
    "    train, test = map(lambda l: torch.concatenate(l, dim=1), zip(*[split_slice(slice, **kwargs) for slice in slices]))\n",
    "    pos_train, vol_train = train.T.split((2, 1), dim=1)\n",
    "    pos_test, vol_test = test.T.split((2, 1), dim=1)\n",
    "\n",
    "    edge_index = generate_edge_index(pos_train, pos_test, subsample_size=subsample_size, radius=radius)\n",
    "    with torch.no_grad():\n",
    "        vol_train_predict, vol_test_predict = gno(x=vol_train, pos_x=pos_train, pos_y=pos_test, edge_index=edge_index)\n",
    "\n",
    "    error_train = torch.abs((vol_train_predict - vol_train) / vol_train)\n",
    "    error_test = torch.abs((vol_test_predict - vol_test) / vol_test)\n",
    "\n",
    "    q = torch.tensor([0.05, 0.5, 0.95])\n",
    "    q_train= torch.concatenate((error_train.mean().unsqueeze(0), torch.quantile(error_train, q=q)))\n",
    "    q_test = torch.concatenate((error_test.mean().unsqueeze(0), torch.quantile(error_test, q=q)))\n",
    "\n",
    "    return q_train, q_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_test_error_2(gno_dataset[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_error_2(data, n: int = 10, **kwargs):\n",
    "    q_train, q_test = map(torch.stack, zip(*[compute_test_error_2(data, **kwargs) for _ in range(n)]))\n",
    "    return q_train.mean(dim=0), q_test.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_train_list = {}\n",
    "q_test_list = {}\n",
    "for data in gno_dataset:\n",
    "    q_train, q_test = average_error_2(data, n=25\n",
    "                                      , q=0.5, extrapolate=False)\n",
    "    q_train_list[data.quote_datetime] = q_train.tolist()\n",
    "    q_test_list[data.quote_datetime] = q_test.tolist()\n",
    "\n",
    "print(pd.DataFrame(q_train_list).T.describe(percentiles=(0.05, 0.5, 0.95))[0]) \n",
    "print(pd.DataFrame(q_test_list).T.describe(percentiles=(0.05, 0.5, 0.95))[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_train_list = {}\n",
    "q_test_list = {}\n",
    "for data in gno_dataset:\n",
    "    q_train, q_test = average_error_2(data, n=25, q=0.6, extrapolate=True)\n",
    "    q_train_list[data.quote_datetime] = q_train.tolist()\n",
    "    q_test_list[data.quote_datetime] = q_test.tolist()\n",
    "\n",
    "print(pd.DataFrame(q_train_list).T.describe(percentiles=(0.05, 0.5, 0.95))[0]) \n",
    "print(pd.DataFrame(q_test_list).T.describe(percentiles=(0.05, 0.5, 0.95))[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "operator-deep-smoothing-for-implied-volati-Lc0MA9F8-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
