{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning Limit Order Book Dynamics with NSPDE.\n",
    "\n",
    "In this notebook we train the NSPDE model by minimising an expected signature kernel score to learn limit order book dynamics."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ../"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import scipy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "from src.gan.discriminators_spde import SigKerMMDDiscriminator, ExpectedSigKerScoreDiscriminator\n",
    "from src.gan.generators_spde import Generator\n",
    "from src.gan.output_functions import plot_loss\n",
    "from src.gan.base import stopping_criterion\n",
    "from statsmodels.tsa.stattools import acf\n",
    "from src.evaluation_functions import generate_ks_results_nspde"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_cuda = torch.cuda.is_available()\n",
    "device = 'cuda' if is_cuda else 'cpu'\n",
    "\n",
    "if not is_cuda:\n",
    "    print(\"Warning: CUDA not available; falling back to CPU but this is likely to be very slow.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.load('../data/lob_neurips.npy')\n",
    "data = torch.tensor(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_x = 20\n",
    "dim_t = 32\n",
    "output_dim = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gridT = torch.linspace(0,1,dim_t).repeat(dim_x,1)\n",
    "gridX = torch.linspace(0,1,dim_x).unsqueeze(-1).repeat(1,dim_t)\n",
    "\n",
    "grid = torch.stack([gridX, gridT], dim=-1).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Begin training\n",
    "\n",
    "In this section we set the training parameters for the GAN. Each parameter is annotated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator arguments\n",
    "dim = 1\n",
    "initial_noise_size = 1              # How many noise dimensions to sample at the start of the SDE.\n",
    "noise_size         = 2              # How many dimensions the Wiener process motion has.\n",
    "hidden_size        = 16             # How big the hidden state of the generator NSPDE is.\n",
    "noise_type         = \"white\"        # Noise type argument for torchspde\n",
    "integration_method = \"fixed_point\"  # Integration method to solve the latent SPDE\n",
    "fixed              = True           # Whether to fix the starting point or not\n",
    "data_size          = data.shape[-1]\n",
    "modes1             = 10             # Number of modes to perform convolution in the temporal dimension\n",
    "modes2             = 20             # Number of modes to perform convolution in the spatial dimension\n",
    "n_iter             = 4              # Number of iterations to solve the fixed point problem (if integration_method is \"fixed_point\")\n",
    "\n",
    "# Discriminator args\n",
    "dyadic_order   = 1                  # Mesh size of PDE solver used in loss function\n",
    "kernel         = \"rbf_id\"           # Type of kernel to use in the discriminator\n",
    "use_phi_kernel = False              # Whether we want to take averages of signature kernels \n",
    "n_scalings     = 3                  # Number of kernels to average\n",
    "sigma         = {'sigma':10}        # hyperparameters of the kernel\n",
    "adversarial   = False               # Whether to adversarially train the discriminator or not.\n",
    "max_batch     = 32                  # Maximum batch size to pass through the discriminator.\n",
    "loss_evals    = 1                   # Number of evaluations before doing gradient step\n",
    "if not adversarial:\n",
    "    discriminator_type = \"scoring\"\n",
    "else:\n",
    "    discriminator_type = \"mmd\"\n",
    "\n",
    "# Training hyperparameters\n",
    "generator_lr     = 1e-03         # Generator initial learning rate\n",
    "generator_mom    = 0.            # (Optional) momentum parameter for generator\n",
    "discriminator_lr = 3e-03         # Discriminator initial learning rate\n",
    "disriminator_mom = 0.            # (Optional) momentum parameter for discriminator\n",
    "batch_size       = 32            # Batch size (set above in the data extraction)\n",
    "steps            = 2000          # How many steps to train both generator and discriminator for.\n",
    "gen_optim        = \"Adam\"        # Optimiser for generator\n",
    "\n",
    "weight_decay     = 5*1e-04       # Weight decay.\n",
    "disc_optim       = \"Adam\"        # Optimiser for discriminator\n",
    "\n",
    "# Evaluation and plotting hyperparameters\n",
    "steps_per_print  = 50             # How often to print the loss.\n",
    "update_freq  = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n",
    "infinite_train_dataloader = (elem for it in iter(lambda: dataloader, None) for elem in it)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Init Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create correlation function for the noise in input to the generator\n",
    "def my_smooth_corr(x, a, r = 2):\n",
    "    my_eps=0.001\n",
    "    j = 1.*torch.arange(1,x.shape[0]+1).to(x.device)\n",
    "    j[-1] = 0.\n",
    "    q = j**(-(2*r+1+my_eps)/2)\n",
    "    q[-1]=0\n",
    "    res = torch.sqrt(q)*torch.sqrt(2. / a) * torch.sin(j * torch.pi * x / a)\n",
    "    return res\n",
    "\n",
    "if noise_type[0] == 'r':\n",
    "    input_roughness = int(noise_type.split('_')[-1])\n",
    "    noise_type = lambda x,a : my_smooth_corr(x, a, r = input_roughness)\n",
    "# Initialise the generator\n",
    "generator = Generator(\n",
    "    dim=dim,\n",
    "    data_size=output_dim, \n",
    "    initial_noise_size=initial_noise_size,\n",
    "    noise_size=noise_size, \n",
    "    hidden_size=hidden_size, \n",
    "    initial_point='given',\n",
    "    noise_type=noise_type,\n",
    "    integration_method=integration_method,\n",
    "    modes1=modes1, \n",
    "    modes2=modes2,\n",
    "    n_iter=n_iter,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.3 Init Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the discriminator\n",
    "if discriminator_type.lower() == \"scoring\":\n",
    "    discriminator = ExpectedSigKerScoreDiscriminator(\n",
    "        kernel_type = kernel,\n",
    "        dyadic_order = dyadic_order,\n",
    "        sigma=sigma,\n",
    "        path_dim = output_dim,\n",
    "        adversarial=adversarial,\n",
    "        max_batch=max_batch,\n",
    "        use_phi_kernel = use_phi_kernel,\n",
    "        n_scalings = n_scalings\n",
    "    ).to(device)\n",
    "else:\n",
    "    discriminator = SigKerMMDDiscriminator(\n",
    "        kernel_type=kernel, \n",
    "        dyadic_order=dyadic_order, \n",
    "        path_dim=output_dim, \n",
    "        sigma=sigma,\n",
    "        adversarial=adversarial,\n",
    "        max_batch=max_batch\n",
    "    ).to(device)\n",
    "    \n",
    "mmd = SigKerMMDDiscriminator(\n",
    "        kernel_type = kernel,\n",
    "        dyadic_order = dyadic_order,\n",
    "        sigma=sigma,\n",
    "        path_dim = output_dim,\n",
    "        adversarial=adversarial,\n",
    "        max_batch=max_batch,\n",
    "        use_phi_kernel = use_phi_kernel,\n",
    "        n_scalings = n_scalings\n",
    "    ).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Train the NSPDE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_optimiser = torch.optim.Adam(generator.parameters(), lr=generator_lr)#, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "alpha = .1\n",
    "tr_loss = torch.zeros(steps, requires_grad=False).to(device)\n",
    "sigmas = torch.zeros(steps, requires_grad=False).to(device)\n",
    "step_vec = np.arange(steps)\n",
    "trange = tqdm(range(steps), position=0, leave=True)\n",
    "\n",
    "for step in trange:\n",
    "\n",
    "    loss = 0\n",
    "    for _ in range(loss_evals):\n",
    "        real_samples      = next(infinite_train_dataloader)\n",
    "        \n",
    "        if real_samples.shape[0]!=batch_size:\n",
    "            real_samples      = next(infinite_train_dataloader)\n",
    "        \n",
    "        real_samples = real_samples.to(device)\n",
    "        u0 = real_samples[:,0,:].permute(0,2,1).float()\n",
    "        generated_samples = generator(grid, batch_size, u0)\n",
    "        loss += discriminator(alpha*torch.cumsum(generated_samples,dim=1), alpha*torch.cumsum(real_samples,dim=1).detach())\n",
    "\n",
    "    loss /= loss_evals\n",
    "    loss.backward()\n",
    "\n",
    "    tr_loss[step] = loss.clone()\n",
    "\n",
    "    generator_optimiser.step()\n",
    "    generator_optimiser.zero_grad()\n",
    "    \n",
    "    if ((step % update_freq == 0) or (step == steps-1)):\n",
    "        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))\n",
    "        ax1_ = ax1.twinx() \n",
    "        # Plot updated loss\n",
    "        with torch.no_grad():\n",
    "            np_tr_loss = tr_loss.detach().cpu().numpy()\n",
    "            np_tr_sigma = sigmas.detach().cpu().numpy()\n",
    "\n",
    "        current_loss = np_tr_loss[:step]\n",
    "        current_sigma = np_tr_sigma[:step]\n",
    "        print(current_loss)\n",
    "        future_loss  = 0. if len(current_loss) == 0 else np.min(current_loss) - np.std(current_loss)\n",
    "        future_sigma  = 0. if len(current_sigma) == 0 else np.min(current_sigma) - np.std(current_sigma)\n",
    "        current_steps = step_vec[:step]\n",
    "        current_sigmas = step_vec[:step]\n",
    "        future_steps  = step_vec[step:]\n",
    "        future_loss   = np.array([future_loss for _ in range(future_steps.shape[0])])\n",
    "        future_sigma   = np.array([future_sigma for _ in range(future_steps.shape[0])])\n",
    "\n",
    "        ax1.plot(current_steps, current_loss, alpha=1., color=\"dodgerblue\", label=\"training_loss\")\n",
    "        ax1.plot(future_steps, future_loss, alpha=0.)\n",
    "        ax1_.plot(current_steps, current_sigma, alpha=1., color=\"red\", label=\"sigma\")\n",
    "        ax1_.plot(future_steps, future_sigma, alpha=0.)\n",
    "\n",
    "        ax1.grid(visible=True, color='grey', linestyle=':', linewidth=1.0, alpha=0.3)\n",
    "        ax1.minorticks_on()\n",
    "        ax1.grid(visible=True, which='minor', color='grey', linestyle=':', linewidth=1.0, alpha=0.1)\n",
    "        \n",
    "        ax1.legend()\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            \n",
    "            to_plot = np.sort(np.random.choice(data.shape[0], size=1000, replace=False))\n",
    "            u0 = data[to_plot,0,:].to(device).permute(0,2,1).float()\n",
    "            generated = generator(grid, 1000, u0)\n",
    "            p = stopping_criterion(generated[:500,...,0], data[to_plot][:500,...,0], cutoff=1., tol=0.05)\n",
    "\n",
    "        ax2.hist(generated.cpu().detach().numpy()[:,5,5,0], color='red', alpha=0.5,bins=25, density=True)\n",
    "        ax2.hist(data[to_plot,5,5,0].numpy(), color='blue', alpha=0.5, bins=25, density=True)\n",
    "\n",
    "        ax2.grid(visible=True, color='grey', linestyle=':', linewidth=1.0, alpha=0.3)\n",
    "        ax2.minorticks_on()\n",
    "        ax2.grid(visible=True, which='minor', color='grey', linestyle=':', linewidth=1.0, alpha=0.1)\n",
    "        \n",
    "        ax2.legend()\n",
    "        \n",
    "        \n",
    "        display.clear_output(wait=True)\n",
    "        display.display(plt.gcf())\n",
    "    if (step % steps_per_print) == 0 or step == steps - 1:\n",
    "\n",
    "        trange.write(f\"Step: {step:3} Total loss (unaveraged): {loss.item():.5e} K-S passed: {p} \")\n",
    "\n",
    "# ###############################################################################\n",
    "# ## 5. Training complete\n",
    "# ################################b###############################################\n",
    "# torch.save(generator.state_dict(), get_project_root().as_posix() + \"/notebooks/models/generator.pkl\")\n",
    "# torch.save(discriminator.state_dict(), get_project_root().as_posix() + \"/notebooks/models/discriminator.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Evaluation of the marginals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_batch_size = 128\n",
    "dataloader = torch.utils.data.DataLoader(data, batch_size=eval_batch_size, shuffle=True)\n",
    "infinite_train_dataloader = (elem for it in iter(lambda: dataloader, None) for elem in it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import ks_2samp\n",
    "from scipy import stats as st\n",
    "\n",
    "marginals = tuple([i*1./30 for i in range(30)])\n",
    "alpha     = 0.95\n",
    "tol       = 1 - alpha\n",
    "n_runs    = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dims = dim_x\n",
    "path_length = dim_t\n",
    "generators = [generator]\n",
    "\n",
    "mean_ks = np.zeros((len(generators), dims, len(marginals)))\n",
    "\n",
    "\n",
    "total_ks_results = generate_ks_results_nspde(\n",
    "    grid, infinite_train_dataloader, generators, marginals, n_runs, dims=dims, eval_batch_size=eval_batch_size, device=device\n",
    ")\n",
    "\n",
    "for k in range(dims):\n",
    "    print(f\"Dimension: {k}\")\n",
    "    for i, m in enumerate(marginals):\n",
    "        print(f\"Marginal {int(path_length*m)}:\")\n",
    "\n",
    "        for j, disc in enumerate(['rbf_id']):\n",
    "\n",
    "            average_score  = np.mean(total_ks_results[j, :, k, i, 0])\n",
    "            std_score      = np.std(total_ks_results[j, :, k, i, 0])\n",
    "            percent_reject = sum(total_ks_results[j, :, k, i, 1] <= tol)/n_runs\n",
    "            \n",
    "            mean_ks[j,k,i] = average_score\n",
    "            \n",
    "            lci, hci = st.norm.interval(alpha, loc=average_score, scale=std_score)\n",
    "\n",
    "            print(f\"{disc}: Average KS score: {average_score:.4f}, \" \n",
    "                  f\"% reject: {percent_reject*100:.1f}, CI: {lci:.4f}, {hci:.4f}\")\n",
    "        print(\"\")\n",
    "        \n",
    "        \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.heatmap(mean_ks[0], cmap = sns.color_palette(\"mako_r\", as_cmap=True))\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NSPDE",
   "language": "python",
   "name": "nspde"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
