{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import logging\n",
    "import json\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "import fnmatch\n",
    "from random import shuffle\n",
    "import tempfile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store Directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder for persisted artifacts; the next cell derives `checkpoints/` and `data/` subfolders from it.\n",
    "storedir = None  # Set this to save evaluation results/checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the artifact folders from `storedir`; both stay None when persistence is disabled.\n",
    "if storedir is not None:\n",
    "    checkpoint_storedir = f\"{storedir}/checkpoints\"\n",
    "    # parents=True: also create `storedir` itself when it does not exist yet\n",
    "    Path(checkpoint_storedir).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    data_storedir = f\"{storedir}/data\"\n",
    "    Path(data_storedir).mkdir(parents=True, exist_ok=True)\n",
    "else:\n",
    "    checkpoint_storedir = None\n",
    "    data_storedir = None\n",
    "\n",
    "# PBS job id (text before '.pbs') when PBS_JOBID is set, otherwise 'local'\n",
    "job_id = os.environ.get('PBS_JOBID', 'local').split('.pbs')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the default root handler, then use a dedicated 'job' logger at INFO level\n",
    "# for all progress messages emitted by this notebook.\n",
    "logging.basicConfig()\n",
    "logger = logging.getLogger('job')\n",
    "logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info('Importing third-party packages ...')\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from op_ds.gno.gno import GNOLayer, GNO\n",
    "from op_ds.gno.kernel import NonlinearKernelTransformWithSkip\n",
    "from op_ds.utils.fnn import FNN\n",
    "from volatility_smoothing.utils.gno.train import Trainer\n",
    "from volatility_smoothing.utils.options_data import CBOEOptionsDataset\n",
    "from volatility_smoothing.utils.gno.dataset import GNOOptionsDataset\n",
    "from volatility_smoothing.utils.chunk import chunked\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the first CUDA device when available, otherwise fall back to the CPU.\n",
    "logger.info(f\"Defining device (torch.cuda.is_available()={torch.cuda.is_available()})\")\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "logger.info(f'Running using device `{device}`')\n",
    "\n",
    "if device.type == 'cuda':\n",
    "    # Log the GPU status table. The raw bytes are decoded properly instead of patching\n",
    "    # escape sequences in their repr, and a missing `nvidia-smi` binary is not fatal.\n",
    "    try:\n",
    "        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE)\n",
    "    except OSError as error:\n",
    "        logger.warning(f'Could not run nvidia-smi: {error}')\n",
    "    else:\n",
    "        formatted_result = result.stdout.decode(errors='replace')\n",
    "        logger.info(formatted_result)\n",
    "\n",
    "    logger.info(f'Device count: {torch.cuda.device_count()}')\n",
    "    # CUDA_VISIBLE_DEVICES is optional; report a placeholder instead of raising KeyError\n",
    "    logger.info(f'Visible devices count: {os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"<unset>\")}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset layout: <data_dir>/train, <data_dir>/dev and one folder per test split (test_*).\n",
    "data_dir = \"../data/cboe\"\n",
    "train_dir = data_dir + \"/train\"\n",
    "dev_dir = data_dir + \"/dev\"\n",
    "# Test split folders are collected by name pattern and kept in sorted order\n",
    "test_dirs = sorted(\n",
    "    f\"{data_dir}/{split_name}\"\n",
    "    for split_name in fnmatch.filter(os.listdir(data_dir), 'test_*')\n",
    ")\n",
    "\n",
    "\n",
    "def read_filepaths(dir):\n",
    "    \"\"\"Return the paths of all `*.pt` files located directly inside `dir`.\n",
    "\n",
    "    Each path is built as `<dir>/<filename>`; the order follows `os.listdir(dir)`.\n",
    "    \"\"\"\n",
    "    filepaths = []\n",
    "    for tensor_file in fnmatch.filter(os.listdir(dir), '*.pt'):\n",
    "        filepaths.append(\"{}/{}\".format(dir, tensor_file))\n",
    "    return filepaths\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap every split directory in a CBOEOptionsDataset (one dataset per test_* folder).\n",
    "train_dataset = CBOEOptionsDataset(cache_dir=train_dir)\n",
    "dev_dataset = CBOEOptionsDataset(cache_dir=dev_dir)\n",
    "test_datasets = [\n",
    "    CBOEOptionsDataset(cache_dir=split_dir)\n",
    "    for split_dir in test_dirs\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture hyperparameters\n",
    "in_channels = 1\n",
    "out_channels = 1\n",
    "channels = (in_channels, 16, 16, 16, out_channels)\n",
    "spatial_dim = 2\n",
    "gno_channels = 16\n",
    "hidden_channels = 64\n",
    "\n",
    "gno_layers = []\n",
    "\n",
    "# One GNO layer per consecutive pair in `channels`. Every layer lifts channels[i] to\n",
    "# `gno_channels`; only the last layer projects back down to channels[i+1].\n",
    "for i in range(m := (len(channels) - 1)):\n",
    "    is_last = (i == m - 1)\n",
    "\n",
    "    lifting = FNN.from_config(\n",
    "        (channels[i], hidden_channels, gno_channels),\n",
    "        hidden_activation='gelu', batch_norm=False,\n",
    "    )\n",
    "    if is_last:\n",
    "        projection = FNN.from_config(\n",
    "            (gno_channels, hidden_channels, channels[i+1]),\n",
    "            hidden_activation='gelu', batch_norm=False,\n",
    "        )\n",
    "    else:\n",
    "        projection = None\n",
    "    transform = NonlinearKernelTransformWithSkip(\n",
    "        in_channels=gno_channels, out_channels=gno_channels, skip_channels=in_channels,\n",
    "        spatial_dim=spatial_dim, hidden_channels=(hidden_channels, hidden_channels),\n",
    "        hidden_activation='gelu', batch_norm=False,\n",
    "    )\n",
    "\n",
    "    # The first layer is built without the local linear term, all later layers with it\n",
    "    local_linear = (i != 0)\n",
    "\n",
    "    # GELU between layers, Softplus on the final layer\n",
    "    activation = torch.nn.Softplus(beta=0.5) if is_last else torch.nn.GELU()\n",
    "\n",
    "    gno_layer = GNOLayer(\n",
    "        gno_channels, transform=transform, local_linear=local_linear, local_bias=True,\n",
    "        activation=activation, lifting=lifting, projection=projection,\n",
    "    )\n",
    "    gno_layers.append(gno_layer)\n",
    "\n",
    "gno = GNO(*gno_layers, in_channels=in_channels).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_checkpoint(model, optimizer, path):\n",
    "    \"\"\"Restore model and optimizer state from the checkpoint file at `path`.\n",
    "\n",
    "    The file is read with `torch.load` (mapped onto the notebook-level `device`) and must\n",
    "    provide the keys 'model' and 'optimizer' holding the respective state dicts. Both\n",
    "    objects are updated in place and returned as a `(model, optimizer)` tuple.\n",
    "    \"\"\"\n",
    "    state = torch.load(path, map_location=device)\n",
    "    for target, key in ((model, 'model'), (optimizer, 'optimizer')):\n",
    "        target.load_state_dict(state[key])\n",
    "    logger.info(f\"Loaded checkpoint from {path}\")\n",
    "    return model, optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file that is loaded into `gno` (and its optimizer) in the next cell\n",
    "path = \"../train/store/9448705/checkpoints/checkpoint_final.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `load_checkpoint` needs an optimizer to restore into. Note that `optimizer` is reassigned\n",
    "# further down to a fresh AdamW using the finetuning lr / weight decay.\n",
    "optimizer = torch.optim.AdamW(gno.parameters())\n",
    "gno, optimizer = load_checkpoint(gno, optimizer, path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation/Finetuning Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-4  # Learning rate of the finetuning AdamW optimizer\n",
    "weight_decay = 1e-5  # Weight decay of the finetuning AdamW optimizer\n",
    "epochs = 10  # Finetune epochs, set to 0 to skip and just evaluate\n",
    "batch_size = 64  # Finetune batch size, will be augmented by same amount of training data\n",
    "\n",
    "# mesh sizes on which to evaluate arbitrage metrics (forwarded to trainer.evaluate)\n",
    "step_r = 0.02\n",
    "step_z = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer()  # Use default parameters\n",
    "# The training split is wrapped with subsample=True, the dev/test splits with subsample=False\n",
    "gno_train_dataset = GNOOptionsDataset(train_dataset, subsample=True)\n",
    "gno_dev_dataset = GNOOptionsDataset(dev_dataset, subsample=False)\n",
    "gno_test_datasets = [GNOOptionsDataset(test_dataset, subsample=False) for test_dataset in test_datasets]\n",
    "# Fresh AdamW for finetuning; this replaces the optimizer restored from the checkpoint\n",
    "optimizer = torch.optim.AdamW(gno.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worker count handed to trainer.evaluate and to the finetuning DataLoaders\n",
    "num_workers = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics of the loaded checkpoint on the dev split, before any finetuning step\n",
    "df_val, df_rel, df_fit = trainer.evaluate(\n",
    "    gno, gno_dev_dataset,\n",
    "    device=device, num_workers=num_workers,\n",
    "    storedir=storedir, logger=logger,\n",
    "    step_r=step_r, step_z=step_z,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()\n",
    "\n",
    "logger.info(50 * \"=\")\n",
    "logger.info(f\"Evaluation start (Retraining epochs: {epochs}).\")\n",
    "logger.info(50 * \"=\")\n",
    "\n",
    "# (label, key) pairs of the loss components shown in the progress bar / log line\n",
    "_LOSS_REPORT = (\n",
    "    (\"mape\", \"mape\"), (\"wmape\", \"wmape\"), (\"fit pen\", \"fit\"), (\"cal pen\", \"cal\"),\n",
    "    (\"but pen\", \"but\"), (\"reg_z pen\", \"reg_z\"), (\"reg_r pen\", \"reg_r\"),\n",
    ")\n",
    "\n",
    "\n",
    "def _accumulate_sample(sample, scale, drop_missing_grids=False):\n",
    "    \"\"\"Run forward and backward for one sample; gradients accumulate in `model`.\n",
    "\n",
    "    `sample` is the `(data, input, aux)` triple yielded by a DataLoader and `scale` the\n",
    "    divisor applied to the loss before `backward()`. With `drop_missing_grids=True`,\n",
    "    `None` entries of `aux['grids']` are removed before the grids are moved to `device`.\n",
    "    Returns the scaled, detached loss and the loss details returned by `trainer.loss`.\n",
    "    \"\"\"\n",
    "    data, inputs, aux = sample\n",
    "    data = data.to(device)\n",
    "    inputs = {key: val.to(device) for key, val in inputs.items()}\n",
    "    grids = aux['grids']\n",
    "    if drop_missing_grids:\n",
    "        grids = [grid for grid in grids if grid is not None]\n",
    "    aux['grids'] = [grid.to(device) for grid in grids]\n",
    "\n",
    "    output = model(**inputs)\n",
    "    sample_loss, sample_loss_info = trainer.loss(data, output, aux)\n",
    "    sample_loss = sample_loss / scale\n",
    "    sample_loss.backward()\n",
    "    # Detached: the value is only needed for reporting, so the graph can be released\n",
    "    return sample_loss.detach(), sample_loss_info\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmpdir:  # Create empty temporary directory (to init with empty list)\n",
    "    gno_finetune_dataset = GNOOptionsDataset(CBOEOptionsDataset(cache_dir=tmpdir), subsample=True)\n",
    "\n",
    "model = gno.to(device).train()\n",
    "try:\n",
    "    for k, gno_test_dataset in enumerate(gno_test_datasets):\n",
    "\n",
    "        # 1) Evaluate the current model on the next test set ...\n",
    "        logger.info(f\"Evaluating model\")\n",
    "        model.eval()\n",
    "        df_val, df_rel, df_fit = trainer.evaluate(model, gno_test_dataset, device=device,\n",
    "                                                  num_workers=num_workers, storedir=storedir, logger=logger,\n",
    "                                                  step_r=step_r, step_z=step_z)\n",
    "\n",
    "        model.train()\n",
    "\n",
    "        # 2) ... then add that test set to the finetune pool and retrain on it.\n",
    "        gno_finetune_dataset.options_dataset._data += gno_test_dataset.options_dataset._data\n",
    "        idx_list = list(range(len(gno_finetune_dataset)))\n",
    "\n",
    "        for epoch in range(epochs):\n",
    "\n",
    "            logger.info(f'Finetune step {k}. Epoch {epoch}/{epochs}')\n",
    "            logger.info(f\"Loss weights: {trainer.error_weights}\")\n",
    "\n",
    "            shuffle(idx_list)\n",
    "\n",
    "            t_dataloader = DataLoader(gno_train_dataset, batch_size=1, collate_fn=trainer.collate_fn, shuffle=True, num_workers=num_workers, pin_memory=False)\n",
    "            f_dataloader = DataLoader(gno_finetune_dataset, batch_size=1, collate_fn=trainer.collate_fn, shuffle=True, num_workers=num_workers, pin_memory=False)\n",
    "\n",
    "            t_its = iter(t_dataloader)\n",
    "            f_its = iter(f_dataloader)\n",
    "\n",
    "            # `idx_list` only fixes the number and size of the batches; the samples themselves\n",
    "            # come from the two shuffled DataLoaders.\n",
    "            iterations = tqdm(chunked(idx_list, batch_size))\n",
    "            for count, batch_idx in enumerate(iterations):\n",
    "\n",
    "                # Every index contributes one training sample and one finetune sample\n",
    "                bs = 2 * len(batch_idx)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                batch_loss = 0\n",
    "                loss_infos = []\n",
    "\n",
    "                for _ in batch_idx:\n",
    "\n",
    "                    # Base data\n",
    "                    sample_loss, sample_loss_info = _accumulate_sample(next(t_its), bs)\n",
    "                    batch_loss = batch_loss + sample_loss\n",
    "                    loss_infos.append(sample_loss_info)\n",
    "\n",
    "                    # Finetune data\n",
    "                    sample_loss, sample_loss_info = _accumulate_sample(next(f_its), bs, drop_missing_grids=True)\n",
    "                    batch_loss = batch_loss + sample_loss\n",
    "                    loss_infos.append(sample_loss_info)\n",
    "\n",
    "                loss_details = {name: [loss_info[name] for loss_info in loss_infos] for name in loss_infos[0]}\n",
    "                loss_str = [f\"{label}: {sum(loss_details[key]) / bs :> 10.3g}\" for label, key in _LOSS_REPORT]\n",
    "\n",
    "                loss_s = f\"loss: {batch_loss: .8f} ({', '.join(loss_str)})\"\n",
    "                iterations.set_description(loss_s)\n",
    "\n",
    "                # Use the explicit batch counter: tqdm refreshes `.n` lazily, and `len()` of a\n",
    "                # tqdm wrapping a generator raises, whereas `.total` is simply None in that case.\n",
    "                if (count % 10 == 0) and (storedir is not None):\n",
    "                    logger.info(f\"{k}-{epoch}-{iterations.total}-{count} -- {loss_s}\")\n",
    "\n",
    "                optimizer.step()\n",
    "\n",
    "            # Dev evaluation after every epoch, in eval mode like the test evaluation above\n",
    "            model.eval()\n",
    "            df_val, df_rel, df_fit = trainer.evaluate(model, gno_dev_dataset, device=device, num_workers=num_workers)\n",
    "            logger.info(f\"Epoch {epoch} Dev: {df_val.describe()}\")\n",
    "            model.train()\n",
    "\n",
    "            # Persist metrics and checkpoint only when a store directory is configured\n",
    "            if checkpoint_storedir is not None:\n",
    "                df_val.to_csv(f\"{checkpoint_storedir}/val_{k}-{epoch}.csv\")\n",
    "                checkpoint = {\n",
    "                    'model': model.state_dict(),\n",
    "                    'optimizer': optimizer.state_dict(),\n",
    "                }\n",
    "                torch.save(checkpoint, f\"{checkpoint_storedir}/checkpoint_{k}-{epoch}.pt\")\n",
    "\n",
    "except KeyboardInterrupt:\n",
    "    # Nothing to flush here: every sample loss has already been back-propagated, so a\n",
    "    # further backward() on the accumulated batch loss could only fail.\n",
    "    logger.info(\"Training aborted\")\n",
    "else:\n",
    "    logger.info(\"Training complete\")\n",
    "finally:\n",
    "    if checkpoint_storedir is not None:\n",
    "        # Save the state dicts (not the module object) so `load_checkpoint` can read the file\n",
    "        checkpoint = {\n",
    "            'model': model.state_dict(),\n",
    "            'optimizer': optimizer.state_dict(),\n",
    "        }\n",
    "        torch.save(checkpoint, f\"{checkpoint_storedir}/checkpoint_final.pt\")\n",
    "    model.eval()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "op-ds-cqZ6S183-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
