{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Training Neural SDEs with the signature kernel scoring rule: The conditional case"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we show how one can use the signature kernel and the signature kernel scoring rule to train a generative model to learn conditional distributions on path space. \n",
    "\n",
    "#### Introduction.\n",
    "Suppose one observes path segments $x: [0, s] \\to \\mathbb{R}^d$ for some $0 < s< T$, and the associated resultant paths conditional on $x$ are given by $y: [s, T] \\to \\mathbb{R}^d$. One can build a bank of input-output pairs $\\{(x^i, y^i)\\}_{i=1}^N$ where $N$ is the dataset size. The problem is to train a generator to learn the conditional distributon $\\mathbb{P}_{X^{\\text{true}}}(\\cdot |x)$, where $x$ is a given path segment up until $s \\in [0, T]$. \n",
    "\n",
    "We will use a neural SDE as a generator. However, the architecture will need to be modified to admit path segments as conditions. In general, the generator will need to take the form \n",
    "\n",
    "\\begin{equation}\n",
    "    G: \\Theta \\times \\mathcal{Z} \\times C([0, s]; \\mathbb{R}^d) \\to C([s, T]; \\mathbb{R}^d),\n",
    "\\end{equation}\n",
    "\n",
    "where $(\\mathcal{Z}, \\mathbb{P}_\\mathcal{Z})$ is the associated latent space (here, taken to be Wiener space $(W, \\mathbb{W})$, as we do not enforce an initial condition distributon on paths sampled from $G$). \n",
    "\n",
    "#### Generator architectures. \n",
    "\n",
    "We propose the following architecture to conditionalize path data. For each conditioning path $x$, extract the order $N$ signature $S^N(x)$ and feed these to the generator. In theory the (untruncated) signature mapping is sufficient to deliniate between conditioning paths. However in practice one is not able to calculate this infinite-dimensional object. Thus there is a trade-off. \n",
    "\n",
    "The signature method can become prohibitive from a memory standpoint (especially as the dimensionality of the paths expands), however it does not increase the size of the parameter set and thus can be more efficent from a runtime perspective. Formally, the functions in the generator in a forward pass are all augmented with the values\n",
    "\n",
    "\\begin{equation*}\n",
    "    S^N(x) = (1, \\mathbb{X}^1(x), \\mathbb{X}^2(x), \\dots, \\mathbb{X}^N(x)),\n",
    "\\end{equation*}\n",
    "\n",
    "where\n",
    "\n",
    "\\begin{equation*}\n",
    "    \\mathbb{X}^k(x) = \\bigoplus_{i_1, \\dots, i_k \\in \\{1, \\dots, d\\}^k} \\int \\dots \\int_{0\\le t_1 \\le t_k\\le T} dx^{i_1}_{t_1}\\dots dx^{i^k}_{t_k}\n",
    "\\end{equation*}\n",
    "\n",
    "is the signature of level $k$ of the path $x$. In this way the neural networks defining the vector fields in the generator (for instance) are now mappings defined as \n",
    "\n",
    "\\begin{gather*}\n",
    "    \\mu_\\theta: [s, T] \\times \\mathbb{R}^y \\times \\mathbb{R}^{1 + d + d^2 + \\dots + d^N} \\to \\mathbb{R}^y, \\\\\n",
    "    \\sigma_\\theta: [s, T] \\times \\mathbb{R}^y \\times \\mathbb{R}^{1 + d + d^2 + \\dots + d^N} \\to \\mathbb{R}^{y \\times w}. \\\\\n",
    "\\end{gather*}\n",
    "\n",
    "Dimensionality issues can be avoided by using the log-signature. In the notebook, we provide options for different path transformations on the space of input and output paths, along with the order and type of signature taken on the conditioning paths.\n",
    "\n",
    "#### Discriminator and loss function.\n",
    "\n",
    "As we only observe one \"true\" path $y: [s, T] \\to \\mathbb{R}^d$ for each conditioning path $x$, one cannot traditional metrics on the space of measures on path space (K-L, MMD, Wasserstein). Attempts to generate counterfactuals from the observed paths rely on certain assumptions about the data which need not hold in practice, or in fact be directly refutably (stationary and non-Markovianity are two obvious ones). One could use a conditional version of the MMD$^1$ however this machinery relies on conditional mean embeddings$^2$</b> and thus a different approach is required. \n",
    "\n",
    "We propose training conditional NSDEs using the signature kernel scoring rule $\\phi_{\\text{sig}}$. In this setting, the training loss becomes\n",
    "\n",
    "\\begin{equation*}\n",
    "    \\min_\\theta \\mathcal{L}'(\\theta) \\quad \\text{where} \\quad \\mathcal{L}'(\\theta)  = \\min_\\theta \\mathbb{E}_{x \\sim \\mathbb{Q}}\\mathbb{E}_{y \\sim \\mathbb{P}_{X^{\\text{true}}}(\\cdot|x)}\\left[\\phi_{\\text{sig}}(\\mathbb{P}_{X^\\theta}(\\cdot|x), y)\\right].\n",
    "\\end{equation*}\n",
    "\n",
    "The loss over a batch $B = \\{(x^i, y^i)\\}_{i=1}^{n}$ is thus given by\n",
    "\n",
    "\\begin{equation*}\n",
    "    \\min_\\theta \\frac{1}{n}\\sum_{i=1}^n \\phi_{\\text{sig}}(\\mathbb{P}_{X^\\theta}(\\cdot|x^i), y^i).\n",
    "\\end{equation*}\n",
    "\n",
    "#### Final notes. \n",
    "\n",
    "This method is theoretically guaranteed, but can be quite tricky to train. Out Of Memory (OOM) errors often occur, primarily due to the <code>.backward()</code> of the loss function. The computational graph can explode quite quickly, especially in the signature method.\n",
    "\n",
    "The hyperparameters provided are able to be run with 8GB of GPU memory. \n",
    "\n",
    "_______________\n",
    "\n",
    "$^1$ See https://arxiv.org/pdf/1606.04218.pdf.\n",
    "\n",
    "$^2$ Terribly unstable computationally, due to matrix inversions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "RUN_PRE_CHECKS = False\n",
    "\n",
    "import math\n",
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "import torch.optim.swa_utils as swa_utils\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scienceplots\n",
    "import signatory\n",
    "\n",
    "from IPython import display\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.gan.base import preprocess_real_data, get_real_data, get_scheduler, \\\n",
    "    calculate_batch_conditional_scoring_loss, evaluate_conditional_scoring_loss\n",
    "from src.gan.generators import PathConditionalCDEGenerator, PathConditionalSigGenerator\n",
    "from src.gan.discriminators import SigKerScoreDiscriminator\n",
    "from src.gan.output_functions import plot_loss\n",
    "from src.utils.helper_functions.global_helper_functions import get_project_root, mkdir\n",
    "from src.utils.helper_functions.data_helper_functions import subtract_initial_point, \\\n",
    "    get_scalings, process_generator, batch_subtract_initial_point, normalize, inv_normalize\n",
    "from src.utils.helper_functions.plot_helper_functions import make_grid, plot_line_error_bars\n",
    "\n",
    "from src.utils.transformations import Transformer\n",
    "\n",
    "plt.style.use('science')\n",
    "plt.rcParams['axes.titlesize']  = 24\n",
    "plt.rcParams['xtick.labelsize'] = 18\n",
    "plt.rcParams['ytick.labelsize'] = 18\n",
    "plt.rcParams['axes.labelsize']  = 22\n",
    "plt.rcParams['text.usetex']         = True\n",
    "plt.rcParams['text.latex.preamble'] = r'\\usepackage{amsmath, amsfonts}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1aabd064050>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Layout. \n",
    "\n",
    "This notebook is broken into four sections: \n",
    "\n",
    "1) <b>Data:</b> The conditional measure to be learnt on pathspace. We present three examples:\n",
    "    1) $gbm_{1, 2, 3} \\to gbm_{1, 2, 3}$: Conditioning paths come from three separate measures. Output path has the same distribution as the conditioning path.\n",
    "    2) $gbm \\to gbm_{1, 2, 3}$ Conditioning paths come from the same measure. Output paths are built from a non-Markovian condition on the conditioning paths. Output paths have a different distribution to the conditioning path.\n",
    "    3) Forex data. Paths are extracted from real data, split uniformly at a chosen point.\n",
    "\n",
    "2) <b>Generator:</b> This will always be a neural SDE, augmented to accept path data as a conditioning variable, as per the Introduction to this notebook.\n",
    "\n",
    "3) <b>Discriminator:</b> We use batched scoring rules, as explained in the Introduction.\n",
    "\n",
    "4) <b>Training the GAN and evaluation:</b> The GAN is then trained. We evaluate the performance visually, and in the synthetic case we are directly able to verify whether the training has been successful or not.\n",
    "\n",
    "There is also a sub-section that can be run before training to check if the GAN is likely to successfully converge. You can force this to run by setting <code>RUN_CHECKS</code> to <code>True</code> in the Imports section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_cuda = torch.cuda.is_available()\n",
    "device = 'cuda' if is_cuda else 'cpu'\n",
    "\n",
    "if not is_cuda:\n",
    "    print(\"Warning: CUDA not available; falling back to CPU but this is likely to be very slow.\")\n",
    "    \n",
    "# You realistially need GPU access (either natively or via cloud computing) to run this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Configuration\n",
    "\n",
    "This is the universal configuration which sets the data loading, generator and discriminator hyperparameters, training types, and so on. Each value is annotated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data arguments\n",
    "## Data hyperparameters\n",
    "output_dim          = 1                     # Dimension of the outputs.\n",
    "batch_size          = 128                   # Size of minibatches when training.\n",
    "path_length         = 48                    # Number of timesteps in paths.\n",
    "dataset_size        = batch_size*512        # Number of paths in the train/test datasets.\n",
    "normalisation       = None                  # How to transform paths. Options are None, \"mean_var\", \"min_max\"\n",
    "scale               = 1e3                   # Float scale to each path dimension.\n",
    "time_add_type       = \"basic\"               # Controls how time is accounted for. \"basic\" -> 0, 1, 2..., \"realistic\" -> real times.\n",
    "\n",
    "## Data generation / dataset arguments\n",
    "data_type             = \"forex\"                      # \"syn_type1\", \"syn_type2\", \"forex\"\n",
    "change_pt             = int(2*path_length/3)         # When to split paths into conditional/resolvant pairs\n",
    "cond_length           = change_pt                    # Length of conditioning paths\n",
    "eval_length           = int(path_length - change_pt) # Length of conditioned paths\n",
    "tt_split              = 0.8                          # Train/test split\n",
    "\n",
    "filter_extremal_paths = False                # Filter out extremal values (by tv and terminal value)\n",
    "filter_extremal_pct   = 0.95                 # Quantile to filter out, if the above is TRUE\n",
    "\n",
    "forex_pairs           = [\"EURUSD\"]           # \"EURUSD\", \"USDJPY\", \"BTCUSD\", \"BRENTCMDUSD\"\n",
    "stride_length         = 1                    # Number of timesteps to stride when creating path data.\n",
    "frequency             = \"M15\"                # Choice of \"H1\", \"M15\", \"M30\"\n",
    "\n",
    "drift_values          = [0.1, 0. , -0.1]     # mu0, mu1, mu2 for synthetic example\n",
    "diffusion_values      = [0.2, 0.2, 0.2]      # sig0, sig1, sig2 for synthetic example\n",
    "process_type          = \"bm\"                 # \"gbm\" or \"bm\"\n",
    "syn2_mask_type        = \"qv\"                 # \"tv\" or \"qv\" (total or quadratic variation)\n",
    "dt                    = 1e0                  # Timestep size for simulation. Note T = path_length*dt.\n",
    "S0                    = 0.                   # Starting value\n",
    "\n",
    "# Path transformation types/arguments\n",
    "do_transforms            = True  # Whether to apply path transformations\n",
    "\n",
    "cond_transformations     = OrderedDict([\n",
    "    (\"scale\"             , False),\n",
    "    (\"visibility\"        , False), \n",
    "    (\"time_difference\"   , False), \n",
    "    (\"time_normalisation\", True),\n",
    "    (\"lead_lag\"          , True), \n",
    "    (\"basepoint\"         , False)\n",
    "])\n",
    "\n",
    "cond_transformation_args = OrderedDict([\n",
    "    (\"scale\"             , {\"scaler\": 1e-1}),\n",
    "    (\"visibility\"        , {}), \n",
    "    (\"time_difference\"   , {}), \n",
    "    (\"time_normalisation\", {}), \n",
    "    (\"lead_lag\"          , {\"time_in\": True, \"time_out\": False, \"time_normalisation\": False}), \n",
    "    (\"basepoint\"         , {})\n",
    "])\n",
    "\n",
    "out_transformations     = OrderedDict([\n",
    "    (\"scale\"             , False),\n",
    "    (\"visibility\"        , False), \n",
    "    (\"time_difference\"   , False), \n",
    "    (\"time_normalisation\", True), \n",
    "    (\"lead_lag\"          , False), \n",
    "    (\"basepoint\"         , False)\n",
    "])\n",
    "\n",
    "out_transformation_args = OrderedDict([\n",
    "    (\"scale\"             , {\"scaler\": 1e-1}),\n",
    "    (\"visibility\"        , {}), \n",
    "    (\"time_difference\"   , {}), \n",
    "    (\"time_normalisation\", {}), \n",
    "    (\"lead_lag\"          , {\"time_in\": True, \"time_out\": False, \"time_normalisation\": False}), \n",
    "    (\"basepoint\"         , {})\n",
    "])\n",
    "\n",
    "subtract_start = True  # You almost always want this to be TRUE\n",
    "\n",
    "# Generator arguments\n",
    "## Unconditional_arguments\n",
    "generator_config = {\n",
    "    \"noise_size\"         : 8,                 # How many dimensions the Brownian motion has.\n",
    "    \"hidden_size\"        : 16,                # Size of the hidden state of the generator SDE.\n",
    "    \"mlp_size\"           : 64,                # Size of the layers in the various MLPs.\n",
    "    \"num_layers\"         : 3,                 # Numer of hidden layers in the various MLPs.\n",
    "    \"activation\"         : \"LipSwish\",        # Activation function to use over hidden layers\n",
    "    \"tanh\"               : True,              # Whether to apply final tanh activation\n",
    "    \"tscale\"             : 1,                 # Clip parameter for tanh, i.e. [-1, 1] to [-c, c]\n",
    "    \"fixed\"              : True,              # Whether to fix the starting point or not\n",
    "    \"noise_type\"         : \"general\",         # Noise type argument for torchsde\n",
    "    \"sde_type\"           : \"ito\",             # SDE integration type from torchsde\n",
    "    \"dt_scale\"           : 1e0,               # Grid shrinking parameter. Lower values are computationally more expensive\n",
    "    \"integration_method\" : \"euler\"            # Integration method for torchsde\n",
    "}\n",
    "\n",
    "## Conditional arguments\n",
    "\n",
    "emp_size         = 32                # Size of empirical measure made by generator\n",
    "\n",
    "conditional_config = {\n",
    "    \"logsig\"          : True,        # Use logsignature in conditional generator\n",
    "    \"order\"           : 5,           # Order of signature to take \n",
    "    \"sig_scale\"       : 1e0,         # Constant x which scales sig level k by x^k\n",
    "}\n",
    "\n",
    "\n",
    "# Discriminator arguments\n",
    "adversarial           = False         # Whether to adversarially train the discriminator or not.\n",
    "clip_disc_param       = True          # MMD-based discriminators only: whether to ensure the scaling param stays above 1   \n",
    "\n",
    "\n",
    "discriminator_args = {\n",
    "    \"dyadic_order\"       : 1,         # Mesh size of PDE solver used in loss function\n",
    "    \"kernel_type\"        : \"rbf\",     # Type of kernel to use in the discriminator\n",
    "    \"sigma\"              : 1e0,       # Sigma in RBF kernel\n",
    "    \"use_phi_kernel\"     : False,     # Whether to use the the phi(k) = (k/2)! scaling.\n",
    "    \"n_scalings\"         : 10,        # Number of samples to draw from Exp(1)\n",
    "    \"max_batch\"          : 32         # Maximum batch size to pass through the discriminator.\n",
    "}\n",
    "\n",
    "\n",
    "# Training hyperparameters\n",
    "## Optimiser parameters\n",
    "generator_lr     = 2e-06         # Generator initial learning rate\n",
    "discriminator_lr = 1e-02         # Discriminator initial learning rate\n",
    "steps            = 10000         # How many steps to train both generator and discriminator for.\n",
    "init_mult1       = 1             # Changing the initial parameter size can help.\n",
    "init_mult2_dr    = 2e0           # Changing vector field MLP initial parameter size.\n",
    "init_mult2_df    = 1e0           # Changing vector field MLP initial parameter size.\n",
    "init_mult3       = 1             # Initial parameter size for discriminator\n",
    "weight_decay     = 0.01          # Weight decay.\n",
    "swa_step_start   = int(steps/2)  # When to start using stochastic weight averaging.\n",
    "gen_optim        = \"Adam\"        # Optimiser for generator\n",
    "disc_optim       = \"Adam\"        # Optimiser for discriminator\n",
    " \n",
    "## Learning rate annealear arguments\n",
    "adapting_lr            = False           # Whether to make the learning rate adaptive.\n",
    "adapting_lr_type       = \"StepLR\"        # LR scheduler type\n",
    "lambda_lr_const        = 0.5             # If LambdaLR, learning rate fraction to reduce to\n",
    "poly_exponent_smoother = -0.5            # If LambdaLR, poly exponent to decrease to\n",
    "mult_const             = 1.01            # If MultiplicativeLR, val of a for \\eta_{t+1} = a\\eta_t\n",
    "gamma_lr               = 0.5             # If StepLR, multiplier for learning rate\n",
    "steps_lr               = int(2*steps/3)  # If StepLR, when to change learning rate\n",
    "max_lr                 = 1e-06           # OneCycleLR: Maximum rate\n",
    "anneal_strategy        = \"cos\"           # OneCycleLR: Anneal type \n",
    "total_steps            = steps           # OneCycleLR: Anneal rate\n",
    "pct_start              = 0.3             # OneCycleLR: Percentage of schedule increasing rate\n",
    "div_factor             = 10              # OneCycleLR: Final lr as a percentage of max\n",
    "\n",
    "## Early stopping\n",
    "early_stopping_type = None         # Early stopping type ('marginals', 'mmd', or None)\n",
    "crit_evals          = 20           # Marginals: number of evaluations of criterion\n",
    "crit_thresh         = 0.975        # Marginals: stopping threshold\n",
    "cutoff              = 0.975        # Marginals: cutoff of marginal distributions to remove extremal values\n",
    "mmd_ci              = 0.95         # mmd: Confidence on null distribution\n",
    "mmd_atoms           = steps_lr     # mmd: number of atoms to build null distribution\n",
    "mmd_periods         = 50           # mmd: number of lagged periods to compare threshold to\n",
    "\n",
    "## Evaluation and plotting hyperparameters\n",
    "steps_per_print  = int(steps/10)              # How often to print the loss.\n",
    "num_plot_samples = int(batch_size/2)          # How many samples to use on the plots at the end.\n",
    "plot_locs        = (0.1, 0.3, 0.5, 0.7, 0.9)  # Plot some marginal distributions at this proportion of the way along."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data\n",
    "\n",
    "We now generate our training and testing data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if data_type == \"forex\":\n",
    "    ## Real data regime\n",
    "    \n",
    "    data_kwargs = {\n",
    "        \"dataset_size\": dataset_size,\n",
    "        \"path_length\": path_length,\n",
    "        \"batch_size\": batch_size,\n",
    "        \"step_size\": stride_length,\n",
    "        \"learning_type\": \"paths\",\n",
    "        \"time_add_type\": time_add_type,\n",
    "        \"train_test_split\": tt_split\n",
    "    }\n",
    "\n",
    "    real_data_kwargs = {\n",
    "        \"pairs\": forex_pairs,\n",
    "        \"frequency\": frequency,\n",
    "        \"filter_extremal_paths\": filter_extremal_paths,\n",
    "        \"filter_extremal_pct\": filter_extremal_pct\n",
    "    }\n",
    "    \n",
    "    np_train_data, np_test_data = preprocess_real_data(data_kwargs, real_data_kwargs)\n",
    "        \n",
    "elif \"syn\" in data_type: \n",
    "    ## We are in the synthetic case\n",
    "    extract_size = int(2*dataset_size)\n",
    "    \n",
    "    if time_add_type == \"basic\":\n",
    "        times = np.linspace(0, path_length-1, path_length)\n",
    "    else:\n",
    "        times = np.linspace(0, path_length*dt, path_length)\n",
    "        \n",
    "    np_data         = np.zeros((extract_size, path_length, int(1 + output_dim)))\n",
    "    np_data[..., 0] = np.tile(times, (extract_size, 1))\n",
    "    \n",
    "    mu0, mu1, mu2    = drift_values\n",
    "    sig0, sig1, sig2 = diffusion_values\n",
    "        \n",
    "    ## Iterate over synthetic test types\n",
    "    if data_type == \"syn_type1\":\n",
    "        ## gbm_i -> gbm_i. \n",
    "\n",
    "        inds  = np.permutation(np.arange(dataset_size))\n",
    "        mask0 = inds[:int(extract_size/3)]\n",
    "        mask1 = inds[int(extract_size/3):int(2*extract_size/3)]\n",
    "        mask2 = inds[int(2*extract_size/3):]\n",
    "        \n",
    "        for i in range(2):\n",
    "            this_mask         = eval(f\"mask{i}\")\n",
    "            this_mu, this_sig = eval(f\"mu{i}\"), eval(f\"sig{i}\")\n",
    "            \n",
    "            np_data[this_mask, :, 1:] = np.expand_dims(process_generator(\n",
    "                sum(this_mask), dt*(path_length-1), path_length, this_mu, this_sig, S0, proc=process_type\n",
    "            ), axis=-1)\n",
    "            \n",
    "    else:\n",
    "        ## gbm -> gbm_{1, 2, 3}\n",
    "        \n",
    "        np_data[:, :cond_length, 1:] = np.expand_dims(process_generator(\n",
    "            extract_size, dt*(cond_length-1), cond_length, mu1, sig1, S0, proc=process_type\n",
    "        ), axis=-1)\n",
    "        \n",
    "        S0_ = np_data[:, cond_length-1, 1]\n",
    "        \n",
    "        # Non-Markovian example\n",
    "        mfunc = lambda x: np.power(x, 2) if syn2_mask_type == \"qv\" else np.abs(x)\n",
    "        \n",
    "        nm_condition = np.sum(mfunc(np.diff(np_data[:, :cond_length, 1], axis=1)), axis=1)\n",
    "        quantiles    = np.quantile(nm_condition, [0.33, 0.66])\n",
    "    \n",
    "        mask0 = nm_condition <= quantiles[0]\n",
    "        mask1 = (nm_condition <= quantiles[1])*(nm_condition > quantiles[0])\n",
    "        mask2 = nm_condition > quantiles[1]\n",
    "        \n",
    "        for i in range(3):\n",
    "            this_mask = eval(f\"mask{i}\")\n",
    "            this_mu   = eval(f\"mu{i}\")\n",
    "            this_sig  = eval(f\"sig{i}\")\n",
    "        \n",
    "            np_data[this_mask, cond_length-1:, 1:] = np.expand_dims(process_generator(\n",
    "                sum(this_mask), dt*(eval_length), eval_length+1, this_mu, this_sig, S0_[this_mask], proc=process_type\n",
    "            ), axis=-1)\n",
    "        \n",
    "    random_indexes = np.random.permutation(np.arange(np_data.shape[0]))\n",
    "    train_indexes  = random_indexes[:dataset_size]\n",
    "    test_indexes   = random_indexes[dataset_size:extract_size]\n",
    "\n",
    "    # Scale everything up here. \n",
    "    #np_data[..., 1] *= scale\n",
    "    \n",
    "    np_train_data = np_data[train_indexes].copy()\n",
    "    np_test_data  = np_data[test_indexes].copy()\n",
    "    \n",
    "else:\n",
    "    np_train_data, np_test_data = None, None\n",
    "    print(\"Data extraction type is not available\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if data_type == \"syn_type2\":\n",
    "    n_path_plots = 64\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(10, 3))\n",
    "    ax1, ax2 = axes\n",
    "    make_grid(axis=ax1)\n",
    "    case1 = np_data[mask0, -1, 1]\n",
    "    case2 = np_data[mask1, -1, 1]\n",
    "    case3 = np_data[mask2, -1, 1]\n",
    "\n",
    "    ax1.hist(case1, bins=128, alpha=0.5, color= \"dodgerblue\" , density=True, label=\"mask_0\")\n",
    "    ax1.hist(case2, bins=128, alpha=0.5, color= \"slategrey\"  , density=True, label=\"mask_1\")\n",
    "    ax1.hist(case3, bins=128, alpha=0.5, color= \"tomato\"     , density=True, label=\"mask_2\")\n",
    "\n",
    "    ax1.legend();\n",
    "\n",
    "    make_grid(axis=ax2)\n",
    "    case1_paths    = np_data[mask0, :, 1][:n_path_plots]\n",
    "    case2_paths    = np_data[mask1, :, 1][:n_path_plots]\n",
    "    case3_paths    = np_data[mask2, :, 1][:n_path_plots] \n",
    "\n",
    "    for i in range(n_path_plots):\n",
    "        label1 = \"mask_0\" if i == 0 else \"\"\n",
    "        label2 = \"mask_1\" if i == 0 else \"\"\n",
    "        label3 = \"mask_2\" if i == 0 else \"\"\n",
    "        ax2.plot(case1_paths[i], color=\"dodgerblue\" , alpha=0.25, label=label1)\n",
    "        ax2.plot(case2_paths[i], color=\"slategrey\", alpha=0.25, label=label2)\n",
    "        ax2.plot(case3_paths[i], color=\"tomato\"    , alpha=0.25, label=label3)\n",
    "    ax2.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Package into torch dataloader\n",
    "(tpast, ts), data_size, train_dataloader = get_real_data(\n",
    "    np_train_data, batch_size, dataset_size, device, time_add_type, time_add_round=4, filter_by_time=True,\n",
    "    split=change_pt, normalisation=normalisation, initial_point=False, scale=scale\n",
    ")\n",
    "\n",
    "_, _, test_dataloader = get_real_data(\n",
    "    np_test_data, batch_size, dataset_size, device, time_add_type, time_add_round=4, filter_by_time=True,\n",
    "    split=change_pt, normalisation=normalisation, initial_point=False, scale=scale\n",
    ")\n",
    "\n",
    "infinite_train_dataloader = (elem for it in iter(lambda: train_dataloader, None) for elem in it)\n",
    "\n",
    "cond_transformer   = Transformer(cond_transformations, cond_transformation_args, device).to(device) if do_transforms else lambda x: x\n",
    "out_transformer    = Transformer(out_transformations, out_transformation_args, device).to(device) if do_transforms else lambda x: x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Discriminator\n",
    "\n",
    "Here, we initialize the discriminator object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator = SigKerScoreDiscriminator(\n",
    "    path_dim     = output_dim, \n",
    "    adversarial  = adversarial,\n",
    "    **discriminator_args\n",
    ").to(device)\n",
    "\n",
    "if adversarial:\n",
    "    for param in discriminator.parameters():\n",
    "        param = param*init_mult3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Generator\n",
    "\n",
    "As mentioned in the introduction, this initialises the conditional generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conditioning_dim = data_size\n",
    "\n",
    "if cond_transformations[\"lead_lag\"]:\n",
    "    if cond_transformation_args[\"lead_lag\"][\"time_out\"]:\n",
    "        conditioning_dim += 1\n",
    "\n",
    "extra_config = {k: conditional_config[k] for k in (\"logsig\", \"order\", \"sig_scale\")}\n",
    "total_config = {**generator_config, **extra_config, **{\"conditioning_dim\" : conditioning_dim}}\n",
    "\n",
    "generator    = PathConditionalSigGenerator(data_size=data_size, **total_config).to(device)\n",
    "    \n",
    "with torch.no_grad():            \n",
    "    for name, prm in generator._func.named_parameters():\n",
    "        if \"_drift\" in name:\n",
    "            prm.data*=init_mult2_dr\n",
    "        else:\n",
    "            prm.data*=init_mult2_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Optimisers and Annealers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "averaged_generator     = swa_utils.AveragedModel(generator)\n",
    "averaged_discriminator = swa_utils.AveragedModel(discriminator)\n",
    "\n",
    "generator_optimiser_    = getattr(torch.optim, gen_optim)\n",
    "generator_optimiser     = generator_optimiser_(generator.parameters(), lr=generator_lr, weight_decay=weight_decay)\n",
    "generator_optimiser.zero_grad()\n",
    "\n",
    "if adversarial:\n",
    "    discriminator_optimiser = getattr(torch.optim, disc_optim)(discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay)\n",
    "    discriminator_optimiser.zero_grad()\n",
    "else:\n",
    "    discriminator_optimiser = None\n",
    "        \n",
    "if adapting_lr:\n",
    "    if adapting_lr_type == \"LambdaLR\":\n",
    "        f_lmd = lambda epoch: (1-lambda_lr_const)*np.power(epoch + 1., poly_exponent_smoother) + lambda_lr_const\n",
    "        adpt_kwargs = {\"lr_lambda\": f_lmd}\n",
    "    elif adapting_lr_type == \"StepLR\":\n",
    "        adpt_kwargs = {\"gamma\": gamma_lr, \"step_size\":steps_lr}\n",
    "    elif adapting_lr_type == \"MultiplicativeLR\":\n",
    "        adpt_kwargs = {\"lr_lambda\": lambda epoch: mult_const}\n",
    "        \n",
    "    elif adapting_lr_type == \"OneCycleLR\":\n",
    "        adpt_kwargs = {\n",
    "            \"max_lr\": max_lr, \n",
    "            \"total_steps\": total_steps, \n",
    "            \"pct_start\": pct_start, \n",
    "            \"anneal_strategy\": anneal_strategy,\n",
    "            \"div_factor\": div_factor\n",
    "        }\n",
    "    \n",
    "    g_scheduler, d_scheduler = get_scheduler(\n",
    "        generator_optimiser, \n",
    "        discriminator_optimiser, \n",
    "        adapting_lr_type, \n",
    "        adversarial,\n",
    "        **adpt_kwargs\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Pre-checks \n",
    "\n",
    "These are optional. You can turn specific ones off/on with the variables below. We provide an explanation of each pre-check alongside it. \n",
    "\n",
    "Note that if the <code>RUN_PRE_CHECKS</code> global variable is set to False, none of the checks will run.\n",
    "\n",
    "### 5.1 Initial distribution of (conditional) generated paths.\n",
    "\n",
    "To help facilitate learning, one would like the variance of the initially generated paths to match the real-data paths. The following gives a visual indiciation as to whether this is the case or not. One can adjust the degree of variance with the <code>init_mult2_df</code> parameter in the optimiser section of the configuration. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the initial distribution to calibrate size of initial vector fields (makes training easier!)\n",
    "if RUN_PRE_CHECKS:\n",
    "\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(6, 3))\n",
    "\n",
    "    n_plot_samples = 32\n",
    "\n",
    "    cond_samples, true_samples = next(infinite_train_dataloader)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        cond_samples      = cond_transformer(subtract_initial_point(cond_samples))\n",
    "        generated_samples = batch_subtract_initial_point(generator(ts, cond_samples, emp_size))\n",
    "        true_samples      = subtract_initial_point(true_samples)\n",
    "\n",
    "        generated_plot_samples = generated_samples[:n_plot_samples, 0, :, 1].cpu()\n",
    "        true_plot_samples      = true_samples[:n_plot_samples, :, 1].cpu()\n",
    "\n",
    "        generated_first = True\n",
    "        real_first      = True\n",
    "\n",
    "        for xi, yi in zip(generated_plot_samples, true_plot_samples):\n",
    "            g_kwargs = {\"label\": \"generated\"} if real_first else {}\n",
    "            r_kwargs = {\"label\": \"real\"} if generated_first else {}\n",
    "\n",
    "            plt.plot(xi, color=\"tomato\", alpha=0.25, **g_kwargs)\n",
    "            plt.plot(yi, color=\"dodgerblue\", alpha=0.25, **r_kwargs)\n",
    "\n",
    "            generated_first = False\n",
    "            real_first      = False\n",
    "\n",
    "        plt.legend()\n",
    "        plt.title(\"Initialisation of $G$ against real data\")\n",
    "        make_grid()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 Moment ratio values. \n",
    "\n",
    "Suppose $y: [t, T] \\to \\mathbb{R}^d$ is an observed path conditional on $x \\in \\mathcal{C}([0, t]; \\mathbb{R}^d)$, so $y \\sim \\mathbb{P}_{X^{\\text{true}}}(\\cdot|x)$. \n",
    "\n",
    "For learning with the signature kernel, it helps to check the distribution of the moments of the random variable $\\Delta y = y_T - y_0$. As the signature kernel between two paths is given by\n",
    "\n",
    "\\begin{equation*}\n",
    "    k_{\\mathrm{sig}}(x, y) = \\sum_{k\\ge 0} \\sum_{i_1, \\dots, i_k \\in \\{1, 2, \\dots, d\\}^k} S(x)^{i_1, \\dots, i_k} \\cdot S(y)^{i_1, \\dots, i_k},\n",
    "\\end{equation*}\n",
    "\n",
    "and, for $i_1, \\dots, i_k \\in \\{1, \\dots, d\\}^k$ where $i_1 = i_2 = \\dots = i_k$,\n",
    "\n",
    "\\begin{equation*}\n",
    "    S(y)^{i_1, \\dots, i_k} = \\int_{t \\le t_k \\le T} \\dots \\int_{t \\le t_1 \\le t_2} dy^{i_1}_{t_1} \\dots dy^{i_k}_{t_k} = \\frac{(\\Delta y^{i_k})^k}{k!},\n",
    "\\end{equation*}\n",
    "\n",
    "so we approximate the magnitude of each level via the increment of $y$ along a given spatial dimension. These values are of interest as (we theorize) the direction of steepest descent when backpropagating through the batched scoring rule will be given by that which reduces the impact coming from the largest terms. This is why (in practice) the generator tends to learn the distribution of the marginal increments first, then their variances, and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if RUN_PRE_CHECKS:\n",
    "    sigma  = discriminator_args.get(\"sigma\")\n",
    "    kernel = discriminator_args.get(\"kernel_type\")\n",
    "    \n",
    "    sigma_ = sigma[0] if type(sigma) != float else sigma\n",
    "    \n",
    "    pow_func = lambda x, y: torch.pow(x, y)/math.factorial(y)\n",
    "    powers = np.arange(1, 10).astype(int)\n",
    "    \n",
    "    if \"syn\" in data_type:\n",
    "        for i in range(3):\n",
    "            this_mask         = eval(f\"mask{i}\")\n",
    "            _scale            = scale**2 if kernel == \"rbf\" else scale\n",
    "            these_paths       = np_data[this_mask, :, 1].copy()*(_scale/scale)\n",
    "            these_increments  = these_paths[:, -1] - these_paths[:, 0]\n",
    "            scaled_increments = these_increments/np.sqrt(sigma_) if kernel == \"rbf\" else these_increments\n",
    "            \n",
    "            scaled_increments = torch.tensor(scaled_increments)\n",
    "            \n",
    "            moment_terms = torch.tensor([[pow_func(inc, d) for d in powers] for inc in scaled_increments])\n",
    "            moment_means = moment_terms.mean(axis=0)\n",
    "            moment_stds  = moment_terms.std(axis=0)\n",
    "            \n",
    "            plot_line_error_bars(moment_means, moment_stds, figsize=(6, 3), powers=powers)\n",
    "            \n",
    "    else:\n",
    "                 \n",
    "        _, paths  = next(infinite_train_dataloader)\n",
    "    \n",
    "        increments        = torch.abs(paths[:, -1, 1] - paths[:, 0, 1])\n",
    "        scaled_increments = increments/sigma_ if kernel == \"rbf\" else increments\n",
    "    \n",
    "        moment_terms = torch.tensor([[pow_func(inc, d) for d in powers] for inc in scaled_increments])\n",
    "        moment_means = moment_terms.mean(axis=0)\n",
    "        moment_stds  = moment_terms.std(axis=0)\n",
    "    \n",
    "        plot_line_error_bars(moment_means, moment_stds, figsize=(6, 3), powers=powers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 Scoring rule values. \n",
    "\n",
    "The following checks help give some intuition as to when training might be failing, in the synthetic case. \n",
    "\n",
    "Although properness is always guaranteed, it might be that the loss from guessing the \"middle\" distribution in the synthetic example is comparable to the \"correct\" guess. As a result, the conditional generator can sometimes learn to just output this distribution instead of correctly learning the true conditional distribution. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if RUN_PRE_CHECKS and \"syn\" in data_type:\n",
    "    \n",
    "    norm_data          = normalize(np_data.copy(), normalisation)\n",
    "    n_atoms            = 128\n",
    "    _emp_size          = emp_size\n",
    "    lambd_             = 1e0\n",
    "    \n",
    "    def compare_vectors(A, B):\n",
    "        y = B.view(-1, 1)\n",
    "        x = A.view(1, -1)\n",
    "        y, x = torch.broadcast_tensors(y, x)\n",
    "        \n",
    "        return torch.any(y < x)\n",
    "    \n",
    "    def get_scoring_rule_distribution(X, Y, n_tests, batch_size, discriminator):\n",
    "        x_size, _, _ = X.shape\n",
    "        y_size, _, _ = Y.shape\n",
    "        \n",
    "        res = torch.zeros(n_atoms).to(device)\n",
    "        \n",
    "        x_ints = torch.randint(0, x_size, size=(n_tests, batch_size))\n",
    "        y_ints = torch.randint(0, y_size, size=(n_tests, 1))\n",
    "        \n",
    "        for i in tqdm(range(n_tests)):\n",
    "            \n",
    "            _X = X[x_ints[i]]\n",
    "            _y = Y[y_ints[i]][0]\n",
    "            \n",
    "            res[i] = discriminator(_X, _y)\n",
    "                \n",
    "        return res\n",
    "            \n",
    "    with torch.no_grad():\n",
    "        ptheta1 = torch.tensor(norm_data[mask0, change_pt:, :], dtype=torch.float64).to(device)\n",
    "        ptheta2 = torch.tensor(norm_data[mask2, change_pt:, :], dtype=torch.float64).to(device)\n",
    "        ptheta3 = torch.tensor(norm_data[mask1, change_pt:, :], dtype=torch.float64).to(device)\n",
    "\n",
    "        ptheta1 = out_transformer(ptheta1)*lambd_\n",
    "        ptheta2 = out_transformer(ptheta2)*lambd_\n",
    "        ptheta3 = out_transformer(ptheta3)*lambd_\n",
    "\n",
    "        if subtract_start:\n",
    "            ptheta1 = subtract_initial_point(ptheta1)\n",
    "            ptheta2 = subtract_initial_point(ptheta2)\n",
    "            ptheta3 = subtract_initial_point(ptheta3)\n",
    "\n",
    "        # Check distributions of scoring rules\n",
    "        sXX = get_scoring_rule_distribution(ptheta1, ptheta1, n_atoms, _emp_size, discriminator)\n",
    "        sYY = get_scoring_rule_distribution(ptheta2, ptheta2, n_atoms, _emp_size, discriminator)\n",
    "        sZZ = get_scoring_rule_distribution(ptheta3, ptheta3, n_atoms, _emp_size, discriminator)\n",
    "        \n",
    "        sXY = get_scoring_rule_distribution(ptheta1, ptheta2, n_atoms, _emp_size, discriminator)\n",
    "        sYX = get_scoring_rule_distribution(ptheta2, ptheta1, n_atoms, _emp_size, discriminator)\n",
    "        sZX = get_scoring_rule_distribution(ptheta3, ptheta1, n_atoms, _emp_size, discriminator)\n",
    "        sXZ = get_scoring_rule_distribution(ptheta1, ptheta3, n_atoms, _emp_size, discriminator)\n",
    "        sZY = get_scoring_rule_distribution(ptheta3, ptheta2, n_atoms, _emp_size, discriminator)\n",
    "        sYZ = get_scoring_rule_distribution(ptheta2, ptheta3, n_atoms, _emp_size, discriminator)\n",
    "        \n",
    "        # Plot results\n",
    "\n",
    "        fig, ax = plt.subplots(2, 1, figsize=(8, 6))\n",
    "        h0_dists = [sXX, sYY, sZZ]\n",
    "        h1_dists = [sYX, sZX, sXY, sZY, sXZ, sYZ]\n",
    "        h0_labels = [\"$S(p_1, x)$\", \"$S(p_2, y)$\", \"$S(p_3, z)$\"]\n",
    "        h1_labels = [\"$S(p_2, x)$\", \"$S(p_3, x)$\", \"$S(p_1, y)$\", \"$S(p_3, y)$\", \"$S(p_1, z)$\", \"$S(p_2, z)$\"]\n",
    "        \n",
    "        n_bins = int(n_atoms/10)\n",
    "        for lab, dist in zip(h0_labels, h0_dists):\n",
    "            ax[0].hist(dist.cpu().numpy(), bins=n_bins, alpha=0.5, density=True, label=lab)\n",
    "        ax[0].set_title(\"Distribution of scoring rules, null hypothesis\")\n",
    "        make_grid(axis=ax[0])\n",
    "        ax[0].legend()\n",
    "        \n",
    "        for lab, dist in zip(h1_labels, h1_dists):\n",
    "            ax[1].hist(dist.cpu().numpy(), bins=n_bins, alpha=0.5, density=True, label=lab)\n",
    "        ax[1].set_title(\"Distribution of scoring rules, alternate hypothesis\")\n",
    "        make_grid(axis=ax[1])\n",
    "        ax[1].legend()    \n",
    "        \n",
    "    plt.tight_layout()\n",
    "        \n",
    "    h0_means = torch.tensor([torch.mean(dist) for dist in h0_dists])\n",
    "    h1_means = torch.tensor([torch.mean(dist) for dist in h1_dists])\n",
    "    z_means  = torch.tensor([torch.mean(dist) for dist in [sZX, sZY, sZZ]])\n",
    "\n",
    "    for lab, score in zip(h0_labels, h0_means):\n",
    "        print(f\"H0: Expected scoring rule of {lab[1:-1]}: {score:.4f}\")\n",
    "\n",
    "    print(\"\\n\")\n",
    "\n",
    "    for lab, score in zip(h1_labels, h1_means):\n",
    "        print(f\"H1: Expected scoring rule of {lab[1:-1]}: {score:.4f}\")\n",
    "\n",
    "    if compare_vectors(h0_means, h1_means):\n",
    "        print(\"\\nNot all null distributions are dominated by alternate distributions.\")\n",
    "    else:\n",
    "        print(\"\\nAll null distributions are dominated by alternate distributions.\")\n",
    "    ideal_batch_loss = torch.mean(h0_means)\n",
    "    print(f\"\\nIdealised batch loss: {ideal_batch_loss:.4f}\")\n",
    "    \n",
    "    print(f\"Loss from guessing p3 always: {torch.mean(z_means):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.4 Distribution of conditions.\n",
    "\n",
    "In the case of the signature conditional generator, it can help to check the distribution of the conditional variables (in the synthetic example)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if RUN_PRE_CHECKS and (\"syn\" in data_type):\n",
    "    \n",
    "    h0_conds = norm_data[mask0, :change_pt, :]\n",
    "    h1_conds = norm_data[mask2, :change_pt, :]\n",
    "    \n",
    "    # Get conditional paths\n",
    "    h0_ints = np.random.choice(np.arange(h0_conds.shape[0]), batch_size)\n",
    "    h1_ints = np.random.choice(np.arange(h1_conds.shape[0]), batch_size)\n",
    "\n",
    "    h0_c = cond_transformer(torch.tensor(h0_conds[h0_ints]).to(device))\n",
    "    h1_c = cond_transformer(torch.tensor(h1_conds[h1_ints]).to(device))\n",
    "    \n",
    "    h0_cond_sigs = generator._signature(h0_c)\n",
    "    h1_cond_sigs = generator._signature(h1_c)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        word_func = signatory.lyndon_words if conditional_config[\"logsig\"] else signatory.all_words\n",
    "        words     = word_func(conditional_config[\"order\"], 2)\n",
    "        n_bins = int(batch_size/10)\n",
    "        h0_sig_terms = h0_cond_sigs.cpu().numpy().T\n",
    "        h1_sig_terms = h1_cond_sigs.cpu().numpy().T\n",
    "\n",
    "        fig, axes = plt.subplots(1, h0_sig_terms.shape[0], figsize=(12, 4))\n",
    "        \n",
    "        for i, (ax, s1, s2) in enumerate(zip(axes, h0_sig_terms, h1_sig_terms)):\n",
    "            ax.hist(s1, bins=n_bins, density=True, color=\"dodgerblue\", label=\"h0\", alpha=0.5)\n",
    "            ax.hist(s2, bins=n_bins, density=True, color=\"tomato\",     label=\"h1\", alpha=0.5)\n",
    "            ax.legend()\n",
    "            ax.set_title(words[i])\n",
    "            make_grid(axis=ax)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Train the GAN \n",
    "\n",
    "We are now ready to train the conditional GAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Final training configurations\n",
    "TRAIN_MODEL            = True   # Train the model or load a pre-existing one.\n",
    "\n",
    "update_plots           = True   # See how training is going in real time.\n",
    "num_plot_samples       = 2      # Number of conditional plot samples to display\n",
    "emp_plot_size          = 16     # Size of empirical measure\n",
    "calculate_dataset_loss = False  # Calculate loss over the entire dataset at the print stage. Can be quite slow!\n",
    "\n",
    "gen_fp  = get_project_root().as_posix() + f\"/saved_models/generators/{data_type}/\"\n",
    "disc_fp = get_project_root().as_posix() + f\"/saved_models/discriminators/{data_type}/\"\n",
    "\n",
    "step_vec = np.arange(steps)\n",
    "sigmas   = torch.zeros(steps)\n",
    "loss_kwargs = {\"emp_size\": emp_size, \"subtract_start\": subtract_start}\n",
    "\n",
    "if TRAIN_MODEL:\n",
    "    tr_loss    = torch.zeros(steps, requires_grad=False).to(device)\n",
    "    trange     = tqdm(range(steps), position=0, leave=True)\n",
    "    criterions = []\n",
    "\n",
    "    for step in trange:\n",
    "        \n",
    "        ###############################################################################\n",
    "        ## Calculate loss\n",
    "        ###############################################################################\n",
    "        \n",
    "        cond_samples, true_samples  = next(infinite_train_dataloader)\n",
    "\n",
    "        loss = calculate_batch_conditional_scoring_loss(\n",
    "            ts,\n",
    "            discriminator, \n",
    "            generator,\n",
    "            batch_size,\n",
    "            cond_samples, \n",
    "            true_samples, \n",
    "            cond_transformer,\n",
    "            out_transformer,\n",
    "            **loss_kwargs\n",
    "        )\n",
    "                \n",
    "        loss.backward()\n",
    "\n",
    "        tr_loss[step] += loss.clone().detach()\n",
    "        \n",
    "        ###############################################################################\n",
    "        ## Plotting temporal results\n",
    "        ###############################################################################\n",
    "        if update_plots and ((step % steps_per_print == 0) or (step == steps - 1)):\n",
    "            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))\n",
    "            # Updating loss plot\n",
    "            with torch.no_grad():\n",
    "                np_tr_loss = tr_loss.cpu().numpy()\n",
    "            \n",
    "            ax1.plot(step_vec[:step], np_tr_loss[:step], alpha=1., color=\"dodgerblue\", label=\"training_loss\")\n",
    "            ax1.plot(step_vec[step:], np_tr_loss[step:] + min(np_tr_loss), alpha=0.)\n",
    "            ax1.grid(visible=True, color='grey', linestyle=':', linewidth=1.0, alpha=0.3)\n",
    "            ax1.minorticks_on()\n",
    "            ax1.grid(visible=True, which='minor', color='grey', linestyle=':', linewidth=1.0, alpha=0.1)\n",
    "            ax1.set_title(\"Training loss\")\n",
    "            \n",
    "            # Plotting path examples\n",
    "            \n",
    "            if subtract_start:\n",
    "                true_samples = subtract_initial_point(true_samples)\n",
    "                cond_samples = subtract_initial_point(cond_samples)\n",
    "                \n",
    "            true_times        = true_samples[..., 0].detach().cpu()\n",
    "            cond_times        = cond_samples[..., 0].detach().cpu()\n",
    "            true_plot_times   = true_times[:num_plot_samples].detach()\n",
    "            cond_plot_times   = cond_times[:num_plot_samples].detach()\n",
    "            \n",
    "            \n",
    "            true_plot_samples = true_samples[:num_plot_samples, :, 1:].detach().cpu()\n",
    "            cond_plot_samples = cond_samples[:num_plot_samples, :, 1:].detach().cpu()\n",
    "            \n",
    "            # More plotting business\n",
    "            if subtract_start: \n",
    "                stpts_matrix = torch.tile(cond_plot_samples[:, -1, 0].unsqueeze(-1).unsqueeze(-1), (1, eval_length, 1))\n",
    "            else:\n",
    "                stpts_matrix = torch.zeros((batch_size, eval_length, data_size))\n",
    "                \n",
    "            tps                   = stpts_matrix.detach().cpu()\n",
    "            true_plot_samples_adj = tps.clone() + true_plot_samples\n",
    "            plot_ts               = ts.detach().cpu()\n",
    "            plot_ts              -= torch.diff(plot_ts)[0]\n",
    "            \n",
    "            real_first      = True\n",
    "            cond_first      = True\n",
    "            generated_first = True\n",
    "            \n",
    "            for i, cond_sample_ in enumerate(cond_plot_samples):\n",
    "                kwargs = {'label': 'Conditioning paths'} if cond_first else {}\n",
    "                ax2.plot(tpast.cpu(), cond_sample_, color='slategrey', linewidth=0.5, alpha=0.7, linestyle=\"dotted\", **kwargs)\n",
    "                cond_first = False\n",
    "                \n",
    "                with torch.no_grad():\n",
    "                    generated_samples = generator(ts, cond_transformer(cond_samples[i].unsqueeze(0)), emp_plot_size)[0]\n",
    "                    \n",
    "                    if subtract_start:\n",
    "                        generated_samples = subtract_initial_point(generated_samples)\n",
    "                        generated_samples[..., 1] += torch.tile(cond_sample_[-1, 0], (emp_plot_size, eval_length, 1)).to(device)[..., 0]\n",
    "\n",
    "                    generated_samples = generated_samples[..., 1].cpu()\n",
    "\n",
    "                    for gen_sample_ in generated_samples:\n",
    "                        kwargs = {'label': 'Generated'} if generated_first else {}\n",
    "                        ax2.plot(plot_ts.cpu(), gen_sample_, color='crimson', linewidth=0.5, alpha=0.2, **kwargs)\n",
    "                        generated_first = False\n",
    "                    \n",
    "            for true_sample_ in true_plot_samples_adj:\n",
    "                kwargs = {'label': 'Real'} if real_first else {}\n",
    "                ax2.plot(plot_ts, true_sample_, color='black', linestyle=\"dashed\", linewidth=0.5, alpha=0.9, **kwargs)\n",
    "                real_first = False\n",
    "                \n",
    "            ax2.axvline(plot_ts[0], color=\"slategrey\", linestyle=\"dashed\", alpha=0.5)\n",
    "            ax2.legend()\n",
    "            ax2.grid(visible=True, color='grey', linestyle=':', linewidth=1.0, alpha=0.3)\n",
    "            ax2.minorticks_on()\n",
    "            ax2.grid(visible=True, which='minor', color='grey', linestyle=':', linewidth=1.0, alpha=0.1)\n",
    "            ax2.set_title(\"Conditional distributions against true\")\n",
    "\n",
    "            display.clear_output(wait=True)\n",
    "            display.display(plt.gcf())\n",
    "            \n",
    "        ###############################################################################\n",
    "        ## Step through optimisers and adapting LR schedulers, stochastic weights\n",
    "        ###############################################################################\n",
    "        \n",
    "        if adversarial:\n",
    "            for param in discriminator.parameters():\n",
    "                param.grad *= -1\n",
    "\n",
    "            discriminator_optimiser.step()\n",
    "            discriminator_optimiser.zero_grad()\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                \n",
    "                if clip_disc_param:\n",
    "                    for param in discriminator.parameters():\n",
    "                        param.clamp_(1, 1e2)\n",
    "                        \n",
    "                sigmas[step] = discriminator._sigma.item()\n",
    "            \n",
    "            if adapting_lr:\n",
    "                d_scheduler.step()\n",
    "\n",
    "        generator_optimiser.step()\n",
    "        generator_optimiser.zero_grad()\n",
    "        \n",
    "        if adapting_lr:\n",
    "            g_scheduler.step()\n",
    "\n",
    "        # Stochastic weight averaging of generator (and discriminator, doesn't matter when not adversarial)\n",
    "        if step > swa_step_start:\n",
    "            averaged_generator.update_parameters(generator)\n",
    "            averaged_discriminator.update_parameters(discriminator)\n",
    "        \n",
    "        ###############################################################################\n",
    "        ## Batched loss calculation\n",
    "        ###############################################################################\n",
    "        if (step % steps_per_print) == 0 or step == steps - 1:     \n",
    "            # Print total loss on dataset  \n",
    "            if calculate_dataset_loss:\n",
    "                total_unaveraged_loss = evaluate_conditional_scoring_loss(\n",
    "                    ts, batch_size, discriminator, generator, train_dataloader, cond_transformer, out_transformer, **loss_kwargs\n",
    "                )\n",
    "            else:\n",
    "                total_unaveraged_loss = loss.item()\n",
    "            \n",
    "            if step > swa_step_start:\n",
    "                if calculate_dataset_loss:\n",
    "                    total_averaged_loss = evaluate_conditional_scoring_loss(\n",
    "                        ts, batch_size, averaged_discriminator, averaged_generator.module, train_dataloader, cond_transformer, out_transformer, **loss_kwargs\n",
    "                    )    \n",
    "\n",
    "                else:\n",
    "                    total_averaged_loss = loss.item()\n",
    "\n",
    "                trange.write(f\"Step: {step:3} Loss (unaveraged): {total_unaveraged_loss:.5e} \"\n",
    "                             f\"Loss (averaged): {total_averaged_loss:.5e} \")\n",
    "            else:\n",
    "                trange.write(f\"Step: {step:3} Loss (unaveraged): {total_unaveraged_loss:.5e} \")\n",
    "    \n",
    "    ###############################################################################\n",
    "    ## Training complete\n",
    "    ###############################################################################\n",
    "    mkdir(gen_fp)\n",
    "    mkdir(disc_fp)\n",
    "    torch.save(generator.state_dict(), gen_fp + f\"path_cond_generator.pkl\")\n",
    "    torch.save(discriminator.state_dict(), disc_fp + f\"path_cond_discriminator.pkl\")\n",
    "    \n",
    "    torch.save(generator_config, gen_fp + \"config.pkl\")\n",
    "    torch.save(discriminator_args, disc_fp + \"config.pkl\")\n",
    "    \n",
    "    torch.save(averaged_generator.state_dict(), gen_fp + \"path_cond_generator_averaged.pkl\")\n",
    "    torch.save(averaged_discriminator.state_dict(), disc_fp + \"path_cond_discriminator_averaged.pkl\")\n",
    "\n",
    "    plot_loss(tr_loss)\n",
    "else:\n",
    "    try:\n",
    "        generator_state_dict     = torch.load(gen_fp + \"path_cond_generator.pkl\")\n",
    "        discriminator_state_dict = torch.load(disc_fp + \"path_cond_discriminator.pkl\")\n",
    "\n",
    "        averaged_generator_state_dict     = torch.load(gen_fp + \"path_cond_generator_averaged.pkl\")\n",
    "        averaged_discriminator_state_dict = torch.load(disc_fp + \"path_cond_discriminator_averaged.pkl\")\n",
    "    except FileNotFoundError: \n",
    "        print(\"Model needs to be trained first. Please set TRAIN_MODEL to True.\")\n",
    "    \n",
    "    generator.load_state_dict(generator_state_dict)\n",
    "    discriminator.load_state_dict(discriminator_state_dict)\n",
    "    \n",
    "    averaged_generator.load_state_dict(averaged_generator_state_dict)\n",
    "    averaged_discriminator.load_state_dict(averaged_discriminator_state_dict)"
   ]
  },
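  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note on the averaged models.* The `averaged_generator` and `averaged_discriminator` updated once `step > swa_step_start` behave like PyTorch's stochastic weight averaging wrappers. A minimal sketch of how such wrappers can be constructed (for illustration only; the actual construction happens earlier in the notebook):\n",
    "\n",
    "```python\n",
    "from torch.optim.swa_utils import AveragedModel\n",
    "\n",
    "# Running averages of the weights, updated in the training loop via\n",
    "# averaged_generator.update_parameters(generator) once step > swa_step_start\n",
    "averaged_generator     = AveragedModel(generator)\n",
    "averaged_discriminator = AveragedModel(discriminator)\n",
    "```\n",
    "\n",
    "Since `AveragedModel` wraps the underlying network, the averaged generator is evaluated through `averaged_generator.module` above."
   ]
  },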
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Evaluation \n",
    "\n",
    "We use this section to evaluate the training instance, along with the performance of the conditional GAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRAIN_MODEL:\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 10))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        np_tr_loss = tr_loss.cpu().numpy()\n",
    "\n",
    "    ax.plot(step_vec[:step], np_tr_loss[:step], alpha=1., color=\"dodgerblue\", label=\"training_loss\")\n",
    "    if adversarial:\n",
    "        ax2 = ax.twinx()\n",
    "        ax2.plot(sigmas[:step], color=\"grey\", label=\"discriminator_sigmas\", alpha=0.75)\n",
    "    #ax.plot(step_vec[step:], np_tr_loss[step:], alpha=0.)\n",
    "    ax.grid(visible=True, color='grey', linestyle=':', linewidth=1.0, alpha=0.3)\n",
    "    ax.minorticks_on()\n",
    "    ax.grid(visible=True, which='minor', color='grey', linestyle=':', linewidth=1.0, alpha=0.1)\n",
    "    extra_title = \" against discriminator sigma\" if adversarial else \"\"\n",
    "    ax.set_title(\"Training loss\" + extra_title)\n",
    "    ax.legend()"
   ]
  },
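  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plots below give a qualitative view of the conditional fit. As a lightweight quantitative complement (a sketch, not part of the original pipeline), one can compare the marginals of generated and held-out true continuations at each evaluation time with a two-sample Kolmogorov-Smirnov statistic. The snippet assumes `generator`, `ts`, `num_eval_paths`, `test_dataloader`, `cond_transformer`, `subtract_initial_point` and `batch_subtract_initial_point` as defined above:\n",
    "\n",
    "```python\n",
    "from scipy.stats import ks_2samp\n",
    "\n",
    "with torch.no_grad():\n",
    "    x, y  = next(iter(test_dataloader))\n",
    "    t_x   = cond_transformer(subtract_initial_point(x.clone()))\n",
    "    gen_y = generator(ts, t_x, num_eval_paths).cpu()  # (batch, samples, time, channels)\n",
    "\n",
    "    # Channel 0 is time; channel 1 is the path value\n",
    "    true_vals = subtract_initial_point(y.clone())[..., 1].cpu()  # (batch, time)\n",
    "    gen_vals  = batch_subtract_initial_point(gen_y)[..., 1]      # (batch, samples, time)\n",
    "\n",
    "# One KS statistic per evaluation time, pooling over the batch and the samples\n",
    "ks_stats = [\n",
    "    ks_2samp(gen_vals[:, :, k].flatten().numpy(), true_vals[:, k].numpy()).statistic\n",
    "    for k in range(true_vals.shape[1])\n",
    "]\n",
    "print(f\"Mean per-marginal KS statistic: {sum(ks_stats) / len(ks_stats):.4f}\")\n",
    "```\n",
    "\n",
    "Pooling across conditioning paths only tests an aggregate of the conditional marginals; a stricter per-condition check would require many true continuations per conditioning path $x$."
   ]
  },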
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Synthetic example\n",
    "n_evaluation_samples = 1\n",
    "num_eval_paths       = 32\n",
    "\n",
    "if data_type.lower() != \"forex\":\n",
    "    \n",
    "    # 1. Subsample from the conditioning paths and construct true distributions\n",
    "    eval_indexes    = torch.randperm(batch_size)[:n_evaluation_samples]\n",
    "    eval_cond_paths = cond_samples[eval_indexes].detach().cpu()/scale\n",
    "    eval_true_paths = true_samples[eval_indexes].detach().cpu()/scale\n",
    "    \n",
    "    # 2. Un-normalise (as class tagging was done on the un-normalized paths)\n",
    "    if normalisation is not None:\n",
    "        val1, val2 = get_scalings(np_train_data[..., 1:], normalisation)\n",
    "        val1 = torch.tensor(val1)\n",
    "        val2 = torch.tensor(val2)\n",
    "        \n",
    "        norm_cond_paths = inv_normalize(eval_cond_paths, normalisation, val1=val1, val2=val2)\n",
    "        norm_true_paths = inv_normalize(eval_true_paths, normalisation, val1=val1, val2=val2)\n",
    "    else:\n",
    "        norm_cond_paths = eval_cond_paths\n",
    "        norm_true_paths = eval_true_paths\n",
    "    \n",
    "    # 3. Loop over conditional samples and plot distributions against true\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "    make_grid(axis=ax)\n",
    "    \n",
    "    ts_cond = tpast.cpu()\n",
    "    ts_true = ts.cpu() - 1\n",
    "    \n",
    "    gen_first  = True\n",
    "    cf_first   = True\n",
    "    cond_first = True\n",
    "    true_first = True\n",
    "    \n",
    "    for cp, tp in zip(norm_cond_paths, norm_true_paths):\n",
    "        # Check what class this path belongs to\n",
    "        \n",
    "        cp += S0\n",
    "        \n",
    "        nm_val = mfunc(np.diff(cp[:, 1])).sum()\n",
    "        if nm_val < quantiles[0]:\n",
    "            class_ = 0\n",
    "        elif (nm_val >= quantiles[0]) and (nm_val < quantiles[1]):\n",
    "            class_ = 1\n",
    "        else:\n",
    "            class_ = 2\n",
    "\n",
    "        # Generate bank of counterfactuals\n",
    "        this_mu, this_sig = eval(f\"mu{int(class_)}\"), eval(f\"sig{int(class_)}\")\n",
    "        this_S0           = torch.tile(cp[-1, 1], (num_eval_paths, 1)).flatten().cpu().numpy()\n",
    "              \n",
    "        this_mu  = np.array(this_mu)\n",
    "        this_sig = np.array(this_sig)\n",
    "        cfs      = process_generator(\n",
    "            num_eval_paths, dt*eval_length, eval_length, this_mu, this_sig, this_S0, proc=process_type\n",
    "        )\n",
    "        \n",
    "        # Generate a counterfactual\n",
    "        cp_t     = torch.clone(cp)\n",
    "        \n",
    "        if normalisation is not None:\n",
    "            cp_t = normalize(cp_t, normalisation, val1=val1, val2=val2)\n",
    "            \n",
    "        cp_t[..., 1] = cp[..., 1] * scale\n",
    "        cp_t         = cond_transformer(cp.unsqueeze(0))\n",
    "        \n",
    "        if subtract_initial_point:\n",
    "            cp_t = subtract_initial_point(cp_t)\n",
    "        \n",
    "        gen_cfs = generator(ts, cp_t, num_eval_paths)[0].cpu()/scale\n",
    "        \n",
    "        if normalisation is not None:\n",
    "            gen_cfs = inv_normalize(gen_cfs, normalisation, val1=val1, val2=val2)\n",
    "        \n",
    "        if subtract_initial_point:\n",
    "            gen_cfs      = subtract_initial_point(gen_cfs)\n",
    "            tp           = tp[:, 1] - tp[0, 1]\n",
    "            plot_gen_cfs = gen_cfs[..., 1].detach().cpu().numpy() + cp[-1, 1].item()\n",
    "            tp           = tp + cp[-1, 1].item()\n",
    "            \n",
    "        plot_cond = cp[..., 1].cpu()\n",
    "        \n",
    "        # Plot the results\n",
    "        label = \"$\\text{conditioning_path}$\" if cond_first else \"\"\n",
    "        ax.plot(ts_cond, plot_cond, color=\"slategrey\", alpha=0.3, label=label)\n",
    "        cond_first = False\n",
    "        for cfp in cfs:\n",
    "            label = \"$\\text{true_counterfactual}$\" if cf_first else \"\"\n",
    "            ax.plot(ts_true, cfp, color=\"dodgerblue\", alpha=0.1, label=label)\n",
    "            cf_first = False\n",
    "        for pgcf in plot_gen_cfs:\n",
    "            label = \"$\\text{generated_counterfactual}$\" if gen_first else \"\"\n",
    "            ax.plot(ts_true, pgcf, color=\"tomato\", alpha=0.1, label=label)\n",
    "            gen_first = False\n",
    "        ax.plot(ts_true, tp, color=\"black\", linestyle=\"dashed\", alpha=0.5, label=\"true_path\" if true_first else \"\")\n",
    "        true_first = False\n",
    "        ax.set_title(\"Distribution of conditional generator vs true counterfactuals\")\n",
    "    ax.legend()\n",
    "else:\n",
    "    # 1. Generate examples\n",
    "    path_types = [\"lowest\", \"random\", \"extreme\"]\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "    \n",
    "    \n",
    "    for ax, path_type in zip(axes, path_types):\n",
    "        with torch.no_grad():\n",
    "            x, y    = next(iter(test_dataloader))\n",
    "            t_x     = cond_transformer(subtract_initial_point(x.clone()))\n",
    "            gen_y   = generator(ts, t_x, num_eval_paths).cpu()\n",
    "            t_gen_y = batch_subtract_initial_point(gen_y)\n",
    "            t_y     = subtract_initial_point(y.clone())\n",
    "\n",
    "            # 2. Pick the highest and lowest qv input paths\n",
    "            _x    = subtract_initial_point(x)\n",
    "            qvs   = torch.sum(torch.pow(torch.diff(_x[..., 1], axis=1), 2), axis=1)\n",
    "            s_qvs = torch.argsort(qvs)\n",
    "\n",
    "            if path_type == \"extreme\":\n",
    "                inds_ = [s_qvs[-i] for i in range(1, n_evaluation_samples + 1)]\n",
    "            elif path_type == \"lowest\":\n",
    "                inds_ = [s_qvs[i] for i in range(n_evaluation_samples)]\n",
    "            else:\n",
    "                inds_ = [s_qvs[torch.randint(0, len(s_qvs), (n_evaluation_samples,)).item()]]\n",
    "\n",
    "            # 3. Create plot objects\n",
    "            _p_x     = subtract_initial_point(x)\n",
    "            ts_cond  = _p_x[0, :, 0].cpu()/scale\n",
    "            ts_eval  = t_y[0, :, 0].cpu()/scale - 1\n",
    "            plot_x   = _p_x[inds_, :, 1:].cpu()\n",
    "            plot_y   = t_y[inds_, :, 1:].cpu()\n",
    "            plot_g_y = t_gen_y[inds_, :, :, 1:].cpu()\n",
    "\n",
    "        # 4. Add initial point\n",
    "        n_ts_cond = len(ts_cond)\n",
    "        n_ts_eval = len(ts_eval)\n",
    "        tv_x      = plot_x[:, -1, :]\n",
    "        ip_add    = torch.tile(tv_x, (1, eval_length)).unsqueeze(-1) \n",
    "        plot_y   += ip_add\n",
    "        plot_g_y += torch.tile(ip_add.unsqueeze(1), (1, num_eval_paths, 1, 1))\n",
    "        # 5. Loop over conditions and plot\n",
    "        #make_grid(axis=ax)\n",
    "\n",
    "        #ax.set_ylim([-20, 20])\n",
    "\n",
    "        gen_first  = True\n",
    "        cond_first = True\n",
    "        true_first = True\n",
    "\n",
    "        for px, py, gpy in zip(plot_x, plot_y, plot_g_y):\n",
    "            label = r\"$x$\" if cond_first else \"\"\n",
    "            ax.plot(ts_cond, px, color=\"black\", alpha=0.25, label=label)\n",
    "            cond_first = False\n",
    "            for pgcf in gpy:\n",
    "                label = r\"$\\mathbb{P}_{X^\\theta}(\\cdot|x)$\" if gen_first else \"\"\n",
    "                ax.plot(ts_eval, pgcf, color=\"dodgerblue\", alpha=0.1, label=label)\n",
    "                gen_first = False\n",
    "            ax.plot(ts_eval, py, color=\"black\", linestyle=\"dashed\", alpha=0.5, label=r\"$y \\sim \\mathbb{P}_{X^\\mathrm{true}}(\\cdot|x)$\" if true_first else \"\")\n",
    "            ax.plot(ts_eval, gpy[..., 0].min(axis=0).values, color=\"dodgerblue\", alpha=0.5)\n",
    "            ax.plot(ts_eval, gpy[..., 0].max(axis=0).values, color=\"dodgerblue\", alpha=0.5)\n",
    "\n",
    "            true_first = False\n",
    "        ax.legend(fontsize=18)\n",
    "        ax.axvline(ts_eval[0], color=\"grey\", linestyle=\"dashed\", alpha=0.25)\n",
    "        ax.set_xlabel(r\"$t$\", fontsize=16)\n",
    "        ax.set_ylabel(r\"$X_t$\", fontsize=16)\n",
    "        #ax.set_title(r\"Distribution of generated conditional distribution against realised path, EUR/USD\", fontsize=11)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('cond_results.png', dpi=100)\n",
    "    plt.show()\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
