{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GNO Training for Volatility Smoothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, logging\n",
    "import json\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "import fnmatch\n",
    "from random import shuffle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store Directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "storedir = None  # Set this to save checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "if storedir is not None:\n",
    "    checkpoint_storedir = f\"{storedir}/checkpoints\"\n",
    "    Path(checkpoint_storedir).mkdir(exist_ok=True)\n",
    "else:\n",
    "    checkpoint_storedir = None\n",
    "    \n",
    "try:\n",
    "    job_id = os.environ['PBS_JOBID'].split('.pbs')[0]\n",
    "except KeyError:\n",
    "    job_id = 'local'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig()\n",
    "logger = logging.getLogger('job')\n",
    "logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info('Importing third-party packages ...')\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from op_ds.gno.gno import GNOLayer, GNO\n",
    "from op_ds.gno.kernel import NonlinearKernelTransformWithSkip\n",
    "from op_ds.utils.fnn import FNN\n",
    "from volatility_smoothing.utils.gno.train import Trainer\n",
    "from volatility_smoothing.utils.options_data import CBOEOptionsDataset\n",
    "from volatility_smoothing.utils.gno.dataset import GNOOptionsDataset\n",
    "from volatility_smoothing.utils.chunk import chunked"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(f\"Defining device (torch.cuda.is_available()={torch.cuda.is_available()})\")\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "logger.info(f'Running using device `{device}`')\n",
    "\n",
    "if device.type == 'cuda':\n",
    "    result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE)\n",
    "    formatted_result = str(result.stdout).replace('\\\\n', '\\n').replace('\\\\t', '\\t')##\n",
    "\n",
    "    logger.info(formatted_result)\n",
    "    logger.info(f'Device count: {torch.cuda.device_count()}')\n",
    "    logger.info(f'Visible devices count: {os.environ[\"CUDA_VISIBLE_DEVICES\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dir = \"../data/cboe/train\"\n",
    "dev_dir = \"../data/cboe/dev\"\n",
    "\n",
    "train_dataset = CBOEOptionsDataset(cache_dir=train_dir)\n",
    "dev_dataset = CBOEOptionsDataset(cache_dir=dev_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = 1\n",
    "out_channels = 1\n",
    "channels = (in_channels, 16, 16, 16, out_channels)\n",
    "spatial_dim = 2\n",
    "gno_channels = 16\n",
    "hidden_channels = 64\n",
    "\n",
    "gno_layers = []\n",
    "\n",
    "for i in range(m := (len(channels) - 1)):\n",
    "    lifting = FNN.from_config((channels[i], hidden_channels, gno_channels), hidden_activation='gelu', batch_norm=False)\n",
    "    projection = None if i < m - 1 else FNN.from_config((gno_channels, hidden_channels, channels[i+1]), hidden_activation='gelu', batch_norm=False)\n",
    "    transform = NonlinearKernelTransformWithSkip(in_channels=gno_channels, out_channels=gno_channels, skip_channels=in_channels, spatial_dim=spatial_dim, hidden_channels=(hidden_channels, hidden_channels), hidden_activation='gelu', batch_norm=False)\n",
    "\n",
    "    if i == 0:\n",
    "        local_linear = False\n",
    "    else:\n",
    "        local_linear = True\n",
    "        \n",
    "    activation = torch.nn.GELU() if i < m - 1 else torch.nn.Softplus(beta=0.5)\n",
    "        \n",
    "    gno_layer = GNOLayer(gno_channels, transform=transform, local_linear=local_linear, local_bias=True,\n",
    "                         activation=activation, lifting=lifting, projection=projection)\n",
    "    gno_layers.append(gno_layer)\n",
    "    \n",
    "gno = GNO(*gno_layers, in_channels=in_channels).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(gno)\n",
    "logger.info(sum([p.numel() for p in gno.parameters()]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_weights = {'fit': 1., 'but': 10., 'cal': 10., 'reg_z': 0.01, 'reg_r': 0.01}\n",
    "lr = 1e-4\n",
    "weight_decay = 1e-5\n",
    "epochs = 500\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer()  # Use default parameters\n",
    "gno_train_dataset = GNOOptionsDataset(train_dataset, subsample=True)\n",
    "gno_dev_dataset = GNOOptionsDataset(dev_dataset, subsample=False)\n",
    "optimizer = torch.optim.AdamW(gno.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_workers = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()\n",
    "\n",
    "logger.info(50 * \"=\")\n",
    "logger.info(f\"Training start (Epochs: {epochs}).\")\n",
    "logger.info(50 * \"=\")\n",
    "\n",
    "idx_list = list(range(len(gno_train_dataset)))\n",
    "model = gno.to(device).train()\n",
    "try:\n",
    "    for epoch in range(epochs):\n",
    "\n",
    "        logger.info(f\"Loss weights: {trainer.error_weights}\")\n",
    "        \n",
    "        shuffle(idx_list)\n",
    "        \n",
    "        dataloader = DataLoader(gno_train_dataset, batch_size=1, collate_fn=trainer.collate_fn, shuffle=True, num_workers=num_workers, pin_memory=False)\n",
    "        its = iter(dataloader)\n",
    "\n",
    "        for count, batch_idx in zip(range(len(idx_list)), (iterations := tqdm(chunked(idx_list, batch_size)))):\n",
    "            \n",
    "            bs = len(batch_idx)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            batch_loss = 0\n",
    "\n",
    "            loss_infos = []\n",
    "            loss_str = []\n",
    "            \n",
    "            for i in batch_idx:\n",
    "                \n",
    "                data, input, aux = next(its)\n",
    "                data = data.to(device)\n",
    "                input = {key: val.to(device) for key, val in input.items()}\n",
    "                aux['grids'] = [grid.to(device) for grid in aux['grids']]\n",
    "                \n",
    "                output = model(**input)\n",
    "                sample_loss, sample_loss_info = trainer.loss(data, output, aux)                \n",
    "                sample_loss = sample_loss / bs\n",
    "                sample_loss.backward()\n",
    "                \n",
    "                batch_loss =  batch_loss + sample_loss\n",
    "                loss_infos.append(sample_loss_info)\n",
    "                \n",
    "            loss_details = {k: [loss_info[k] for loss_info in loss_infos] for k in loss_infos[0]}\n",
    "            loss_str.append(f\"mape: {sum(loss_details['mape']) / bs :> 10.3g}\")\n",
    "            loss_str.append(f\"wmape: {sum(loss_details['wmape']) / bs :> 10.3g}\")\n",
    "            loss_str.append(f\"vol pen: {sum(loss_details['fit']) / bs :> 10.3g}\")\n",
    "            loss_str.append(f\"cal pen: {sum(loss_details['cal']) / bs :> 10.3g}\")\n",
    "            loss_str.append(f\"but pen: {sum(loss_details['but']) / bs :> 10.3g}\")\n",
    "            loss_str.append(f\"reg_z pen: {sum(loss_details['reg_z']) / bs :> 10.3g}\")\n",
    "            loss_str.append(f\"reg_r pen: {sum(loss_details['reg_r']) / bs :> 10.3g}\")\n",
    "                                        \n",
    "            loss_s = f\"loss: {batch_loss: .8f} ({', '.join(loss_str)})\"\n",
    "            iterations.set_description(loss_s)\n",
    "\n",
    "            if (iterations.n % 10 == 0) and (storedir is not None):\n",
    "                logger.info(f\"Epoch {epoch}; {iterations.n}/{len(iterations)} -- {loss_s}\")                                \n",
    "                \n",
    "            optimizer.step()\n",
    "  \n",
    "        df_val, df_rel, df_fit = trainer.evaluate(model, gno_dev_dataset, device=device, num_workers=num_workers)\n",
    "        logger.info(f\"Epoch {epoch} Dev: {df_val.describe()}\")\n",
    "        df_val.to_csv(f\"{checkpoint_storedir}/val_{epoch}.csv\")\n",
    "        model.train()\n",
    "\n",
    "        if checkpoint_storedir is not None:\n",
    "            checkpoint = {\n",
    "                'model': model.state_dict(),\n",
    "                'optimizer': optimizer.state_dict(),\n",
    "            }\n",
    "            torch.save(checkpoint, f\"{checkpoint_storedir}/checkpoint_{epoch}.pt\")\n",
    "\n",
    "except KeyboardInterrupt:\n",
    "    try:\n",
    "        batch_loss.backward()\n",
    "    except:\n",
    "        pass\n",
    "    logging.info(\"Training aborted\")\n",
    "else:\n",
    "    logging.info(\"Training complete\")\n",
    "finally:\n",
    "    if checkpoint_storedir is not None:\n",
    "        checkpoint = {\n",
    "            'model': model.state_dict(),\n",
    "            'optimizer': optimizer.state_dict(),\n",
    "        }\n",
    "        torch.save(checkpoint, f\"{checkpoint_storedir}/checkpoint_final.pt\")\n",
    "    model.eval()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "op-ds-cqZ6S183-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
