{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Neural SDEs via the signature kernel scoring rule: The unconditional case"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Background. \n",
    "\n",
    "A <i>generative adversarial network</i> (GAN) is, broadly, a machine learning architecture which trains a generator network $G_\\theta: \\mathcal{Z} \\to \\Omega$ to learn an observed probability distribution $\\mathbb{P}_{X^{\\text{true}}} \\in \\mathcal{P}(\\Omega)$. Here, $(\\mathcal{Z}, \\mathcal{B}(\\mathcal{Z}))$ is a measurable space equipped with a probability measure $\\mathbb{P}_\\mathcal{Z}$, and is often referred to as the <i>latent space</i>. \n",
    "\n",
    "Training is facilitated by a so-called discriminator $D_\\phi: \\mathcal{P}(\\Omega) \\times \\mathcal{P}(\\Omega) \\to \\mathbb{R}$, which serves as a critic of the generator's output. The generator-discriminator pair is trained by solving the min-max problem\n",
    "\n",
    "\\begin{equation*}\n",
    "    \\min_\\theta \\max_\\phi \\; D_\\phi(\\mathbb{P}_{X^\\theta}, \\mathbb{P}_{X^{\\text{true}}}) + \\lambda \\|\\theta\\|_{L^2},\n",
    "\\end{equation*}\n",
    "\n",
    "where $\\mathbb{P}_{X^\\theta} = {G_\\theta}_{\\#}\\mathbb{P}_\\mathcal{Z}$ is the pushforward measure, and the final term is a regularisation penalty on the weights of the generator. \n",
    "\n",
    "### Training Neural SDEs via signature kernel scoring rules. \n",
    "\n",
    "We propose a generator-discriminator pair consisting of a neural SDE (NSDE) and the signature kernel scoring rule $\\phi_{\\text{sig}}$. The generator $G_\\theta: \\mathcal{Z} \\to C([0, T]; \\mathbb{R}^x)$ is given by the solution to the SDE \n",
    "\n",
    "\\begin{align*}\n",
    "    dY_t &= \\mu_\\theta(t, Y_t)dt + \\sigma_\\theta(t, Y_t)\\circ dW_t, \\\\ \n",
    "    Y_0  &= \\xi_\\theta(V), \\\\\n",
    "    X_t  &= \\pi_\\theta(Y_t),\n",
    "\\end{align*}\n",
    "\n",
    "where $V \\sim \\mathbb{P}_\\mathcal{V}$ is some initial $v$-dimensional noise, $\\xi_\\theta: \\mathbb{R}^v \\to \\mathbb{R}^y$ is a linear layer, $\\mu_\\theta: [0, T] \\times \\mathbb{R}^y \\to \\mathbb{R}^y$ and $\\sigma_\\theta: [0, T] \\times \\mathbb{R}^y \\to \\mathbb{R}^{y \\times w}$ are feedforward networks, $W$ is a $w$-dimensional Brownian motion, and $\\pi_\\theta: \\mathbb{R}^y \\to \\mathbb{R}^x$ is a linear layer. Solutions to the SDE take values in the latent space $\\mathbb{R}^y$ and are then mapped to the target space $\\mathbb{R}^x$ by $\\pi_\\theta$. Here, $\\mathcal{Z} = \\mathcal{V} \\times C([0, T]; \\mathbb{R}^w)$, so that $\\mathbb{P}_{X^\\theta} = {G_\\theta}_{\\#}({\\mathbb{P}_\\mathcal{V} \\times \\mathbb{W}})$, where $\\mathbb{W}$ is the Wiener measure.\n",
    "\n",
    "The discriminator is given by the signature kernel scoring rule \n",
    "\n",
    "\\begin{equation*}\n",
    "        \\phi_{\\text{sig}}(\\mathbb{P}, y) := \\mathbb{E}_{x, x' \\sim \\mathbb{P}}[k_{\\text{sig}}(x, x')] - 2\\mathbb{E}_{x \\sim \\mathbb{P}}[k_{\\text{sig}}(x, y)],\n",
    "\\end{equation*}\n",
    "\n",
    "and we train non-adversarially by minimising the expected score \n",
    "\n",
    "\\begin{equation*}\n",
    "   \\min_\\theta \\mathcal{L}(\\theta) \\quad \\text{where} \\quad \\mathcal{L}(\\theta) = \\mathbb{E}_{y\\sim \\mathbb P_{X^{\\text{true}}}}[\\phi_{\\text{sig}}(\\mathbb P_{X^\\theta}, y)].\n",
    "\\end{equation*}\n",
    "\n",
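    "This is equivalent to training with respect to the maximum mean discrepancy (MMD) associated with the signature kernel $k_{\\text{sig}}$. Indeed, $\\mathrm{MMD}^2_{k_{\\text{sig}}}(\\mathbb{P}_{X^\\theta}, \\mathbb{P}_{X^{\\text{true}}}) = \\mathcal{L}(\\theta) + \\mathbb{E}_{y, y' \\sim \\mathbb{P}_{X^{\\text{true}}}}[k_{\\text{sig}}(y, y')]$, and the last term does not depend on $\\theta$. \n",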
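    "\n",
    "The equivalence can be checked numerically. Below is a minimal NumPy sketch that assumes finite-dimensional samples and uses an RBF kernel as a stand-in for $k_{\\text{sig}}$. The helpers <code>rbf_gram</code>, <code>expected_score</code> and <code>mmd2</code> are hypothetical and are not part of <code>src</code>. The sketch estimates $\\mathcal{L}(\\theta)$ and the squared MMD from samples.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def rbf_gram(a, b, sigma=1.0):\n",
    "    # Pairwise RBF kernel matrix k(a_i, b_j); a stand-in for the signature kernel\n",
    "    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)\n",
    "    return np.exp(-sq / (2 * sigma ** 2))\n",
    "\n",
    "def expected_score(x, y, sigma=1.0):\n",
    "    # Empirical L(theta) = E[k(x, x')] - 2 E[k(x, y)]; x from the model, y from the data\n",
    "    n = x.shape[0]\n",
    "    kxx = rbf_gram(x, x, sigma)\n",
    "    e_xx = (kxx.sum() - np.trace(kxx)) / (n * (n - 1))  # unbiased: drop the i == j terms\n",
    "    return e_xx - 2 * rbf_gram(x, y, sigma).mean()\n",
    "\n",
    "def mmd2(x, y, sigma=1.0):\n",
    "    # Unbiased squared MMD = expected score + E[k(y, y')]\n",
    "    m = y.shape[0]\n",
    "    kyy = rbf_gram(y, y, sigma)\n",
    "    e_yy = (kyy.sum() - np.trace(kyy)) / (m * (m - 1))\n",
    "    return expected_score(x, y, sigma) + e_yy\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "y = rng.normal(size=(200, 2))               # 'data' samples\n",
    "x_close = rng.normal(size=(200, 2))         # model matching the data\n",
    "x_far = rng.normal(loc=5.0, size=(200, 2))  # badly specified model\n",
    "```\n",
    "\n",
    "Here <code>mmd2(x_far, y)</code> is clearly larger than <code>mmd2(x_close, y)</code>. The difference <code>mmd2(x, y) - expected_score(x, y)</code> is the same for both models, because it is the estimate of $\\mathbb{E}[k(y, y')]$. Both objectives therefore have the same gradients in $\\theta$.\n",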
    "\n",
    "<b>A note about adversariality</b>. Classical MMD-GANs adversarialise the learning problem by parametrising the kernel with a neural network. Our results can be achieved without any form of adversarial training. However, if one wishes to train adversarially, we include the option of adversarially learning the path scaling parameter. This is enabled by setting the <code>adversarial</code> variable to <code>True</code> in the configuration section. Adversarial training in this manner is only available for the MMD-based discriminators (<code>sigker_mmd</code> and <code>truncated_mmd</code>); the <code>wasserstein_cde</code> discriminator is always trained adversarially."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "import torch.optim.swa_utils as swa_utils\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "from tqdm import tqdm\n",
    "import scienceplots\n",
    "\n",
    "from src.rBergomi import rBergomi\n",
    "from src.gan import sde\n",
    "from src.gan.base import preprocess_real_data, get_real_data, get_synthetic_data, get_scheduler, \\\n",
    "    stopping_criterion, evaluate_loss, get_stopping_criterion_value\n",
    "from src.gan.generators import Generator\n",
    "from src.gan.discriminators import SigKerMMDDiscriminator, TruncatedDiscriminator, CDEDiscriminator\n",
    "from src.gan.output_functions import plot_loss, plot_results\n",
    "from src.utils.helper_functions.global_helper_functions import get_project_root\n",
    "from src.utils.helper_functions.data_helper_functions import subtract_initial_point, build_path_bank, get_scalings\n",
    "from src.utils.helper_functions.plot_helper_functions import make_grid\n",
    "from src.utils.transformations import Transformer\n",
    "\n",
    "plt.style.use(\"science\")\n",
    "plt.rcParams['axes.titlesize']  = 24\n",
    "plt.rcParams['xtick.labelsize'] = 18\n",
    "plt.rcParams['ytick.labelsize'] = 18\n",
    "plt.rcParams['axes.labelsize']  = 22"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Layout. \n",
    "\n",
    "This notebook is broken into four sections: \n",
    "\n",
    "1) <b>Data:</b> The measure on path space to be learnt. Currently we support geometric Brownian motion (gBm), the rough Bergomi (rBergomi) model, and foreign exchange data. \n",
    "\n",
    "2) <b>Generator:</b> This will always be a neural SDE.\n",
    "\n",
    "3) <b>Discriminator:</b> There are three options here. The first is our Sig-MMD. The second is the truncated, linear MMD used in \"Sig-Wasserstein GANs for time series generation\" (Ni et al., 2021). The third is the Wasserstein distance from \"Neural SDEs as infinite-dimensional GANs\" (Kidger et al., 2021). \n",
    "\n",
    "4) <b>Training the NSDE and evaluation:</b> The NSDE is then trained. Performance evaluation can be found in the accompanying notebooks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_cuda = torch.cuda.is_available()\n",
    "device = 'cuda' if is_cuda else 'cpu'\n",
    "\n",
    "if not is_cuda:\n",
    "    print(\"Warning: CUDA not available; falling back to CPU but this is likely to be very slow.\")\n",
    "    \n",
    "# You realistically need GPU access (either natively or via cloud computing) to run this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Configuration\n",
    "\n",
    "This is the universal configuration which sets the data loading, generator and discriminator hyperparameters, training types, and so on. Each value is annotated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Data arguments\n",
    "## Data hyperparameters\n",
    "output_dim          = 1               # Dimension of the outputs.\n",
    "batch_size          = 128             # Size of minibatches when training.\n",
    "path_length         = 64              # Number of timesteps in paths.\n",
    "dataset_size        = batch_size*256  # Number of paths in the train/test datasets.\n",
    "normalisation       = \"mean_var\"      # How to transform paths. Options are None, \"mean_var\", \"min_max\"\n",
    "scale               = 1e0             # Float scale to each path dimension.\n",
    "\n",
    "## Data generation/dataset arguments\n",
    "data_type             = \"gbm\"                           # Choice of \"gbm\", \"rBergomi\", \"forex\".\n",
    "\n",
    "forex_pairs           = [\"EURUSD\", \"USDJPY\"]            # \"EURUSD\", \"USDJPY\", \"BTCUSD\", \"BRENTCMDUSD\"\n",
    "stride_length         = 1                               # Number of timesteps to stride when creating path data.\n",
    "frequency             = \"H1\"                            # Choice of \"H1\", \"M15\", \"M30\"\n",
    "filter_extremal_paths = False                           # Filter out extremal values (by tv and terminal value)\n",
    "filter_extremal_pct   = 0.95                            # Quantile to filter out, if the above is True\n",
    "end_time              = 2                               # End time to simulate until (synthetic examples only)\n",
    "tt_split              = 0.8                             # Train/test split\n",
    "\n",
    "\n",
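    "gbm_params               = [0., 0.2]                    # Parameters of the gBm data generator\n",
    "rB_params                = [0.2**2, 1.5, -0.7, 0.2]     # rBergomi parameters: xi, eta, rho, H\n",
    "cond_                    = data_type.lower() == \"rbergomi\"\n",
    "\n",
    "sde_parameters           = rB_params if cond_ else gbm_params  # Parameters matching data_type\n",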
    "gen_sde_dt_scale         = 1e-1                         # Refine the partitioning of SDE solver that generates data.\n",
    "gen_sde_integration_type = \"ito\"                        # Definition of stochastic integration for data generator\n",
    "gen_sde_method           = \"srk\"                        # Integration method\n",
    "\n",
    "learning_type            = \"paths\"                      # Paths or returns.\n",
    "time_add_type            = \"basic\"                      # \"Basic\": (0, 1, ...). \"Realistic\": actual change (in years)\n",
    "filter_by_time           = True                         # In the case of real data, whether to filter out missampled paths\n",
    "initial_point            = \"scale\"                      # How to handle initial point normalisation. \"scale\" or \"translate\"\n",
    "\n",
    "### Path transformation types/arguments\n",
    "do_transforms            = True  # Whether to apply path transformations\n",
    "\n",
    "transformations     = OrderedDict([\n",
    "    (\"visibility\"        , False), \n",
    "    (\"time_difference\"   , False), \n",
    "    (\"time_normalisation\", False),     # Set this to False when training wrt wasserstein_cde. \n",
    "    (\"lead_lag\"          , False), \n",
    "    (\"basepoint\"         , False)\n",
    "])\n",
    "\n",
    "transformation_args = OrderedDict([\n",
    "    (\"visibility\"        , {}), \n",
    "    (\"time_difference\"   , {}), \n",
    "    (\"time_normalisation\", {}), \n",
    "    (\"lead_lag\"          , {\n",
    "        \"time_in\"           : True, \n",
    "        \"time_out\"          : False, \n",
    "        \"time_normalisation\": False\n",
    "    }), \n",
    "    (\"basepoint\"         , {})\n",
    "])\n",
    "\n",
    "subtract_start = True  # Subtract initial point before calculating loss. You almost always want this to be true.\n",
    "\n",
    "### Generator arguments\n",
    "generator_config = {\n",
    "    \"initial_noise_size\" : 5,                 # How many noise dimensions to sample at the start of the SDE.\n",
    "    \"noise_size\"         : 8,                 # How many dimensions the Brownian motion has.\n",
    "    \"hidden_size\"        : 16,                # Size of the hidden state of the generator SDE.\n",
    "    \"mlp_size\"           : 64,                # Size of the layers in the various MLPs.\n",
    "    \"num_layers\"         : 3,                 # Number of hidden layers in the various MLPs.\n",
    "    \"activation\"         : \"LipSwish\",        # Activation function to use over hidden layers\n",
    "    \"tanh\"               : True,              # Whether to apply final tanh activation\n",
    "    \"tscale\"             : 1,                 # Clip parameter for tanh, i.e. [-1, 1] to [-c, c]\n",
    "    \"fixed\"              : True,              # Whether to fix the starting point or not\n",
    "    \"noise_type\"         : \"general\",         # Noise type argument for torchsde\n",
    "    \"sde_type\"           : \"ito\",             # SDE integration type from torchsde\n",
    "    \"dt_scale\"           : 1e0,               # Grid shrinking parameter. Lower values are computationally more expensive\n",
    "    \"integration_method\" : \"euler\"            # Integration method for torchsde\n",
    "}\n",
    "\n",
    "### Discriminator args\n",
    "discriminator_type = \"sigker_mmd\"   # Options are \"sigker_mmd\", \"wasserstein_cde\", \"truncated_mmd\"\n",
    "adversarial        = False          # Whether to adversarially train the discriminator or not.\n",
    "clip_disc_param    = False          # MMD-based discriminators only: whether to ensure the scaling param stays above 1   \n",
    "\n",
    "\n",
    "## sigker_mmd args\n",
    "sigker_mmd_config = {\n",
    "    \"dyadic_order\"   : 1,         # Mesh size of PDE solver used in loss function\n",
    "    \"kernel_type\"    : \"rbf\",     # Type of kernel to use in the discriminator\n",
    "    \"sigma\"          : 1e0,       # Sigma in RBF kernel\n",
    "    \"use_phi_kernel\" : False,     # Whether to use the phi(k) = (k/2)! scaling. Set \"kernel_type\" to \"linear\".\n",
    "    \"n_scalings\"     : 3,         # Number of samples to draw from Exp(1). ~8 tends to be a good choice.\n",
    "    \"max_batch\"      : 16         # Maximum batch size to pass through the discriminator.\n",
    "}\n",
    "\n",
    "## truncated_mmd args\n",
    "truncated_mmd_config = {\n",
    "    \"order\"         : 6,         # Truncation level\n",
    "    \"scalar_term\"   : False,     # Whether to include the leading 1 term in the signature.\n",
    "}\n",
    "\n",
    "\n",
    "## wasserstein_cde args\n",
    "wasserstein_cde_config = {\n",
    "    \"hidden_size\" : 16,          # Size of the hidden state of the neural CDE\n",
    "    \"num_layers\"  : 3,           # Number of layers in MLPs of NCDE\n",
    "    \"mlp_size\"    : 32           # Number of neurons in each layer\n",
    "}\n",
    "\n",
    "### Training hyperparameters\n",
    "## Optimizer parameters\n",
    "generator_lr     = 1e-03         # Generator initial learning rate\n",
    "discriminator_lr = 2e-03         # Discriminator initial learning rate\n",
    "steps            = 1000          # How many steps to train both generator and discriminator for.\n",
    "init_mult1       = 1             # Changing the initial parameter size can help.\n",
    "init_mult2_dr    = 5e-1          # Changing vector field MLP initial parameter size.\n",
    "init_mult2_df    = 5e-1          # Changing vector field MLP initial parameter size.\n",
    "init_mult3       = 1             # Initial parameter size for discriminator.\n",
    "weight_decay     = 1e-2          # Weight decay (regularizing term).\n",
    "swa_step_start   = int(steps/2)  # When to start stochastic weight averaging (SWA).\n",
    "gen_optim        = \"Adadelta\"    # Optimiser type for generator\n",
    "disc_optim       = \"Adadelta\"    # Optimiser type for discriminator\n",
    "loss_evals       = 1             # Number of times to evaluate the loss before performing an optimizer step\n",
    "\n",
    "## Annealing learning rate parameters\n",
    "adapting_lr            = False          # Whether to make the learning rate adaptive.\n",
    "adapting_lr_type       = \"OneCycleLR\"   # LR scheduler type\n",
    "lambda_lr_const        = 0.5            # If LambdaLR, learning rate fraction to reduce to\n",
    "poly_exponent_smoother = -0.5           # If LambdaLR, poly exponent to decrease to\n",
    "mult_const             = 1.01           # If MultiplicativeLR, val of a for \\eta_{t+1} = a\\eta_t\n",
    "gamma_lr               = 0.25           # If StepLR, multiplier for learning rate\n",
    "steps_lr               = int(steps/10)  # If StepLR, when to change learning rate\n",
    "max_lr                 = 1e-06          # OneCycleLR: Maximum rate\n",
    "anneal_strategy        = \"cos\"          # OneCycleLR: Anneal type \n",
    "total_steps            = steps          # OneCycleLR: Total number of steps in the cycle\n",
    "pct_start              = 0.3            # OneCycleLR: Percentage of schedule increasing rate\n",
    "div_factor             = 10             # OneCycleLR: Initial lr = max_lr/div_factor\n",
    "\n",
    "## Early stopping\n",
    "early_stopping_type = \"marginals\"  # Early stopping type ('marginals' or 'mmd')\n",
    "crit_evals          = 20           # Marginals: number of evaluations of criterion\n",
    "crit_thresh         = 0.99         # Marginals: stopping threshold\n",
    "cutoff              = 1.           # Marginals: cutoff of marginal distributions to remove extremal values\n",
    "mmd_ci              = 0.95         # mmd: Confidence on null distribution\n",
    "mmd_atoms           = steps_lr     # mmd: number of atoms to build null distribution\n",
    "mmd_periods         = 50           # mmd: number of lagged periods to compare threshold to\n",
    "\n",
    "## Evaluation and plotting hyperparameters\n",
    "steps_per_print  = int(steps/10)              # How often to print the loss.\n",
    "num_plot_samples = int(batch_size/2)          # How many samples to use on the plots at the end.\n",
    "plot_locs        = (0.1, 0.3, 0.5, 0.7, 0.9)  # Plot some marginal distributions at this proportion of the way along."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data\n",
    "\n",
    "We begin by pre-loading and setting data arguments. If real data is being used, some preprocessing needs to be completed so the data is in the right form to be wrapped as <code>torch</code> dataloader objects."
   ]
  },
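  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the synthetic <code>gbm</code> case, the target measure is geometric Brownian motion. As a point of reference, the sketch below samples gBm paths exactly on a uniform grid using $S_t = S_0 \\exp((\\mu - \\sigma^2/2)t + \\sigma W_t)$. It assumes that <code>gbm_params = [0., 0.2]</code> are the drift $\\mu$ and volatility $\\sigma$, and that $S_0 = 1$. The function <code>sample_gbm</code> is hypothetical and is not the code path used by <code>get_synthetic_data</code>.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_gbm(n_paths, path_length, end_time, mu=0.0, sigma=0.2, s0=1.0, seed=0):\n",
    "    # Exact simulation of geometric Brownian motion on a uniform time grid\n",
    "    rng = np.random.default_rng(seed)\n",
    "    ts = np.linspace(0.0, end_time, path_length)\n",
    "    dt = np.diff(ts)\n",
    "    dw = rng.normal(scale=np.sqrt(dt), size=(n_paths, path_length - 1))  # Brownian increments\n",
    "    w = np.concatenate([np.zeros((n_paths, 1)), np.cumsum(dw, axis=1)], axis=1)  # W_0 = 0\n",
    "    return ts, s0 * np.exp((mu - 0.5 * sigma ** 2) * ts + sigma * w)\n",
    "\n",
    "ts, paths = sample_gbm(n_paths=4096, path_length=64, end_time=2)\n",
    "```\n",
    "\n",
    "With $\\mu = 0$, the sample mean of the terminal values should be close to $\\mathbb{E}[S_T] = 1$."
   ]
  },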
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If real data, needs to be loaded prior to being packaged up.\n",
    "\n",
    "data_kwargs = {\n",
    "    \"dataset_size\" : dataset_size,\n",
    "    \"path_length\"  : path_length,\n",
    "    \"batch_size\"   : batch_size,\n",
    "    \"step_size\"    : stride_length,\n",
    "    \"learning_type\": learning_type,\n",
    "    \"time_add_type\": time_add_type,\n",
    "    \"initial_point\": initial_point,\n",
    "    \"train_test_split\": tt_split\n",
    "}\n",
    "\n",
    "real_data_kwargs = {\n",
    "    \"pairs\"                : forex_pairs,\n",
    "    \"frequency\"            : frequency,\n",
    "    \"filter_extremal_paths\": filter_extremal_paths,\n",
    "    \"filter_extremal_pct\"  : filter_extremal_pct\n",
    "}\n",
    "\n",
    "if data_type.lower() == \"forex\":  # Real data arguments\n",
    "    np_train_data, np_test_data = preprocess_real_data(data_kwargs, real_data_kwargs)\n",
    "    \n",
    "    ts, data_size, train_dataloader = get_real_data(\n",
    "        np_train_data, \n",
    "        batch_size, \n",
    "        dataset_size, \n",
    "        device, \n",
    "        time_add_type  = time_add_type, \n",
    "        normalisation  = normalisation,\n",
    "        filter_by_time = filter_by_time,\n",
    "        initial_point  = False,\n",
    "        scale          = scale\n",
    "    )\n",
    "    \n",
    "    _, _, test_dataloader = get_real_data(\n",
    "        np_test_data, \n",
    "        batch_size, \n",
    "        dataset_size, \n",
    "        device, \n",
    "        time_add_type  = time_add_type, \n",
    "        normalisation  = normalisation,\n",
    "        filter_by_time = filter_by_time,\n",
    "        initial_point  = False,\n",
    "        scale          = scale\n",
    "    )\n",
    "\n",
    "elif data_type.lower() in [\"gbm\", \"rbergomi\"]:\n",
    "    \n",
    "    sdeint_kwargs = {\n",
    "        \"sde_method\":   gen_sde_method,\n",
    "        \"sde_dt_scale\": gen_sde_dt_scale\n",
    "    }\n",
    "    \n",
    "    if data_type.lower() == \"gbm\":\n",
    "        sde_gen = sde.GeometricBrownianMotion(gen_sde_integration_type, \"diagonal\", *sde_parameters)\n",
    "    elif data_type.lower() == \"rbergomi\":\n",
    "        xi, eta, rho, H = sde_parameters\n",
    "        sde_gen = rBergomi(n=int(path_length/end_time), N=dataset_size, T=end_time, a=H-0.5, rho=rho, eta=eta, xi=xi)\n",
    "    \n",
    "    ts, data_size, train_dataloader = get_synthetic_data(\n",
    "        sde_gen,\n",
    "        batch_size,\n",
    "        dataset_size,\n",
    "        device,\n",
    "        output_dim,\n",
    "        path_length,\n",
    "        normalisation = normalisation,\n",
    "        scale         = scale,\n",
    "        sdeint_kwargs = sdeint_kwargs,\n",
    "        end_time      = end_time,\n",
    "        time_add_type = time_add_type\n",
    "    )\n",
    "    \n",
    "    _, _, test_dataloader = get_synthetic_data(\n",
    "        sde_gen,\n",
    "        batch_size,\n",
    "        dataset_size,\n",
    "        device,\n",
    "        output_dim,\n",
    "        path_length,\n",
    "        normalisation = normalisation,\n",
    "        scale         = scale,\n",
    "        sdeint_kwargs = sdeint_kwargs,\n",
    "        end_time      = end_time,\n",
    "        time_add_type = time_add_type\n",
    "    )\n",
    "    \n",
    "infinite_train_dataloader = (elem for it in iter(lambda: train_dataloader, None) for elem in it)\n",
    "\n",
    "transformer               = Transformer(transformations, transformation_args, device).to(device) if do_transforms else lambda x: x"
   ]
  },
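  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The <code>infinite_train_dataloader</code> above relies on the two-argument form <code>iter(callable, sentinel)</code>. Since <code>lambda: train_dataloader</code> never returns <code>None</code>, it yields the dataloader forever, and the outer generator runs through one fresh pass (epoch) after another. A minimal sketch, with a plain list standing in for the dataloader:\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "batches = ['b0', 'b1', 'b2']  # stand-in for a re-iterable torch DataLoader\n",
    "infinite = (elem for it in iter(lambda: batches, None) for elem in it)\n",
    "first_seven = list(itertools.islice(infinite, 7))\n",
    "print(first_seven)  # ['b0', 'b1', 'b2', 'b0', 'b1', 'b2', 'b0']\n",
    "```\n",
    "\n",
    "<code>itertools.cycle</code> replays the elements saved from the first pass. This idiom instead re-iterates the dataloader on every epoch, so a dataloader created with <code>shuffle=True</code> yields a new ordering each time."
   ]
  },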
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A note about normalisation and scaling.\n",
    "\n",
    "Scaling and normalising the initial data is a very important consideration. Heuristically, this is for the following reasons:\n",
    "\n",
    "1) <b>The signature kernel PDE</b>. Accurate gradients are required when backpropagating through the discriminator, which, as mentioned, is the MMD associated with the signature kernel. The signature kernel is the solution of a Goursat PDE, and its numerical solution is more accurate in practice when the paths take values of small magnitude.\n",
    "\n",
    "2) <b>Distribution of signature terms in the data</b>. The signature kernel is an inner product in $T((\\mathbb{R}^d))$ between signatures, so the MMD measures the distance between expected signatures. If the terms of a certain order are much larger than the others, they dominate the gradients during backpropagation and the remaining terms are unlikely to be learnt. For example, if the first-order terms (the \"drift\") are too large, then path increments will be learnt at the expense of other moments. If higher-order terms dominate, then lower-order terms are not learnt.  \n",
    "\n",
    "3) <b>Initial distribution of generator</b>. Changing scalings and normalisations changes how the initial distribution of paths given by the generator compares with the data. A quick visual check is usually enough to see whether the size of the generator's parameters is appropriate for the given problem. The parameters to focus on here are <code>init_mult2_dr</code> and <code>init_mult2_df</code>."
   ]
  },
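  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Point 2 can be made quantitative. For a straight-line path whose increment has norm $r$, the level-$k$ signature term has norm $r^k/k!$, which is the quantity <code>pow_func</code> computes in the next cell. Small increments give terms that decay monotonically with the level. Larger increments move the dominant terms to higher levels. A minimal sketch, in which the helper <code>level_magnitudes</code> is hypothetical:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def level_magnitudes(r, max_level=8):\n",
    "    # Norm of the level-k signature term of a straight line with increment norm r\n",
    "    return [r ** k / math.factorial(k) for k in range(1, max_level + 1)]\n",
    "\n",
    "small = level_magnitudes(0.5)  # strictly decreasing in k\n",
    "large = level_magnitudes(3.5)  # peaks at level 3\n",
    "dominant_level = large.index(max(large)) + 1\n",
    "```\n",
    "\n",
    "Rescaling the data by a factor $s$ multiplies level $k$ by $s^k$. This is why <code>normalisation</code>, <code>scale</code> and, for the RBF kernel, <code>sigma</code> all change which levels drive the loss."
   ]
  },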
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "STUDY_SIGNATURE_TERMS = False\n",
    "\n",
    "if STUDY_SIGNATURE_TERMS:\n",
    "    \n",
    "    sigmas      = sigker_mmd_config.get(\"sigma\")\n",
    "    kernel_type = sigker_mmd_config.get(\"kernel_type\")\n",
    "    \n",
    "    pow_func = lambda x, y: torch.pow(x, y)/math.factorial(y)\n",
    "    bnd_func = lambda x, y: pow_func(torch.max(torch.sum(torch.abs(torch.diff(x, axis=1)), axis=1)), y)\n",
    "    \n",
    "    do_theo_bnd = False\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        \n",
    "        paths,     = next(infinite_train_dataloader)\n",
    "        _, _, dims = paths.shape\n",
    "        dims -= 1\n",
    "        \n",
    "        paths = paths.cpu()\n",
    "        \n",
    "        if not isinstance(sigmas, list):\n",
    "            sigmas = np.array([sigmas])\n",
    "        \n",
    "        for k in range(dims):\n",
    "            ex_title = f\", dim = {k+1}\" if dims != 1 else \"\"\n",
    "\n",
    "            fig, axes = plt.subplots(1, len(sigmas), figsize=(6*len(sigmas), 3))\n",
    "\n",
    "            if len(sigmas) == 1:\n",
    "                axes = np.array([axes])\n",
    "\n",
    "            for ax, sig in zip(axes, sigmas):\n",
    "\n",
    "                ex_sig = f\", sigma = {sig}\" if kernel_type ==\"rbf\" else \"\"\n",
    "\n",
    "                increments        = torch.abs(paths[:, -1, k+1] - paths[:, 0, k+1])\n",
    "                scaled_increments = increments/np.sqrt(sig) if kernel_type == \"rbf\" else increments\n",
    "\n",
    "                powers = np.arange(1, 10).astype(int)\n",
    "\n",
    "                moment_terms = torch.tensor([[pow_func(inc, d) for d in powers] for inc in scaled_increments])\n",
    "                moment_means = moment_terms.mean(axis=0)\n",
    "                moment_stds  = moment_terms.std(axis=0)\n",
    "                \n",
    "                ax.plot(powers, moment_means, color=\"dodgerblue\", alpha=0.75, label=\"moment_means\")\n",
    "                if do_theo_bnd:\n",
    "                    level_bds    = [bnd_func(paths[..., k+1], d) for d in powers]\n",
    "                    ax.plot(powers, level_bds, color=\"tomato\", linestyle=\"dashed\", alpha=0.75, label=\"theo_bnd\")\n",
    "                    \n",
    "                ax.fill_between(powers, moment_means-moment_stds, moment_means+moment_stds, color=\"dodgerblue\", alpha=0.25)\n",
    "                make_grid(axis=ax)\n",
    "                ax.set_title(\"Scaled increment values (moment ratio)\" + ex_sig + ex_title, fontsize=\"small\")\n",
    "                ax.legend(fontsize=\"small\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Generator\n",
    "\n",
    "As mentioned in the introduction, the generator is given by a neural SDE."
   ]
  },
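  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the generator concrete, here is a minimal NumPy sketch of sampling from the neural SDE in the Background section with an Euler-Maruyama scheme. Euler-Maruyama approximates the Ito solution, which corresponds to the default <code>sde_type</code> of <code>ito</code> and <code>integration_method</code> of <code>euler</code> in <code>generator_config</code>, rather than the Stratonovich form written above. The dimensions mirror the <code>generator_config</code> defaults (<code>initial_noise_size</code> 5, <code>noise_size</code> 8, <code>hidden_size</code> 16). The one-hidden-layer tanh networks with fixed random weights are hypothetical stand-ins. The actual <code>Generator</code> uses trained MLPs with LipSwish activations and solves the SDE with torchsde.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "v, y_dim, w, x_dim, width = 5, 16, 8, 1, 32  # sizes of V, Y, W, X and the MLP width\n",
    "\n",
    "def random_mlp(in_dim, out_dim):\n",
    "    # Hypothetical one-hidden-layer tanh network with fixed random weights\n",
    "    a = rng.normal(scale=in_dim ** -0.5, size=(in_dim, width))\n",
    "    b = rng.normal(scale=width ** -0.5, size=(width, out_dim))\n",
    "    return lambda z: np.tanh(z @ a) @ b\n",
    "\n",
    "xi = rng.normal(scale=v ** -0.5, size=(v, y_dim))          # linear layer xi_theta\n",
    "pi = rng.normal(scale=y_dim ** -0.5, size=(y_dim, x_dim))  # linear readout pi_theta\n",
    "mu = random_mlp(1 + y_dim, y_dim)                          # drift mu_theta(t, y)\n",
    "sigma = random_mlp(1 + y_dim, y_dim * w)                   # diffusion sigma_theta(t, y)\n",
    "\n",
    "def sample_nsde(ts, batch):\n",
    "    y = rng.normal(size=(batch, v)) @ xi  # Y_0 = xi_theta(V)\n",
    "    out = [y @ pi]                        # X_0 = pi_theta(Y_0)\n",
    "    for t0, t1 in zip(ts[:-1], ts[1:]):\n",
    "        dt = t1 - t0\n",
    "        ty = np.concatenate([np.full((batch, 1), t0), y], axis=1)  # network input (t, Y_t)\n",
    "        dw = rng.normal(scale=np.sqrt(dt), size=(batch, w, 1))      # Brownian increments\n",
    "        diffusion = sigma(ty).reshape(batch, y_dim, w)\n",
    "        y = y + mu(ty) * dt + (diffusion @ dw)[..., 0]              # Euler-Maruyama step\n",
    "        out.append(y @ pi)                                          # X_t = pi_theta(Y_t)\n",
    "    return np.stack(out, axis=1)                                    # (batch, len(ts), x_dim)\n",
    "\n",
    "paths = sample_nsde(np.linspace(0.0, 2.0, 64), batch=128)\n",
    "```"
   ]
  },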
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = Generator(data_size=data_size, **generator_config).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    if (not generator_config.get(\"fixed\")) or (learning_type == \"returns\"):\n",
    "        for prm in generator._initial.parameters():\n",
    "            prm.data *= init_mult1\n",
    "\n",
    "    for name, prm in generator._func.named_parameters():\n",
    "        if \"_drift\" in name:\n",
    "            prm.data*=init_mult2_dr\n",
    "        else:\n",
    "            prm.data*=init_mult2_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CHECK_INIT_GENERATOR = True\n",
    "\n",
    "if CHECK_INIT_GENERATOR:\n",
    "    \n",
    "    dims = output_dim if data_type != \"forex\" else len(forex_pairs)\n",
    "    \n",
    "    fig, axes = plt.subplots(dims, 1, figsize=(6, dims*3))\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        y, = next(infinite_train_dataloader)\n",
    "        x  = generator(ts, batch_size)\n",
    "        \n",
    "        if subtract_start:\n",
    "            y  = subtract_initial_point(y)\n",
    "            x  = subtract_initial_point(x)\n",
    "        \n",
    "    if not isinstance(axes, np.ndarray):\n",
    "        axes = np.array([axes])\n",
    "    \n",
    "    for k, ax in enumerate(axes):\n",
    "\n",
    "        N = 32\n",
    "        x_plot = x[:N, :, k+1].cpu()\n",
    "        y_plot = y[:N, :, k+1].cpu()\n",
    "\n",
    "        generated_first = True\n",
    "        real_first      = True\n",
    "\n",
    "        for xi, yi in zip(x_plot, y_plot):\n",
    "            g_kwargs = {\"label\": \"generated\"} if generated_first else {}\n",
    "            r_kwargs = {\"label\": \"real\"} if real_first else {}\n",
    "\n",
    "            ax.plot(xi, color=\"tomato\", alpha=0.5, **g_kwargs)\n",
    "            ax.plot(yi, color=\"dodgerblue\", alpha=0.5, **r_kwargs)\n",
    "\n",
    "            generated_first = False\n",
    "            real_first      = False\n",
    "\n",
    "        ax.legend()\n",
    "        make_grid(axis=ax)\n",
    "        ax_title = \"\" if dims == 1 else f\"Dim {k+1}\"\n",
    "        ax.set_title(ax_title, fontsize=\"small\")\n",
    "    fig.suptitle(\"Initialisation of $G$ against real data\")\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Discriminator\n",
    "\n",
    "We have three options here: the signature kernel MMD, the truncated signature MMD (Ni et al., 2021), and the Wasserstein distance as witnessed by a neural CDE (Kidger et al., 2021)."
   ]
  },
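  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The truncated MMD compares signatures truncated at a fixed level (<code>order</code> in <code>truncated_mmd_config</code>) under a linear kernel. The sketch below computes the truncated signature of a piecewise-linear path using Chen's identity, together with the corresponding linear kernel. All helper names are hypothetical. Mimicking <code>scalar_term</code> by adding the constant level-0 term is an assumption based on its config comment. This is not the <code>TruncatedDiscriminator</code> implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def segment_signature(delta, depth):\n",
    "    # Signature of a straight segment with increment delta: level k is delta^(x k) / k!\n",
    "    levels = [np.ones(())]\n",
    "    for k in range(1, depth + 1):\n",
    "        levels.append(np.multiply.outer(levels[-1], delta) / k)\n",
    "    return levels\n",
    "\n",
    "def chen_product(a, b):\n",
    "    # Chen's identity: level k of the concatenated path is sum_{i+j=k} a_i (x) b_j\n",
    "    return [sum(np.multiply.outer(a[i], b[k - i]) for i in range(k + 1))\n",
    "            for k in range(len(a))]\n",
    "\n",
    "def truncated_signature(path, depth):\n",
    "    # path: (length, d) array of points; returns the list of levels 0..depth\n",
    "    sig = segment_signature(path[1] - path[0], depth)\n",
    "    for p, q in zip(path[1:-1], path[2:]):\n",
    "        sig = chen_product(sig, segment_signature(q - p, depth))\n",
    "    return sig\n",
    "\n",
    "def truncated_linear_kernel(x, y, depth, scalar_term=False):\n",
    "    # Inner product of truncated signatures, optionally including the level-0 term (= 1)\n",
    "    sx, sy = truncated_signature(x, depth), truncated_signature(y, depth)\n",
    "    start = 0 if scalar_term else 1\n",
    "    return sum(float((sx[k] * sy[k]).sum()) for k in range(start, depth + 1))\n",
    "\n",
    "path = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 2.0], [3.0, 1.0]])\n",
    "sig = truncated_signature(path, depth=3)\n",
    "```\n",
    "\n",
    "For this path, <code>sig[1]</code> is the total increment, and the symmetric part of <code>sig[2]</code> is half the outer product of that increment, as the shuffle identity requires."
   ]
  },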
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if discriminator_type.lower() == \"sigker_mmd\":\n",
    "    discriminator = SigKerMMDDiscriminator(\n",
    "        path_dim     = data_size, \n",
    "        adversarial  = adversarial,\n",
    "        **sigker_mmd_config\n",
    "    ).to(device)\n",
    "    \n",
    "elif discriminator_type.lower() == \"truncated_mmd\":\n",
    "    discriminator = TruncatedDiscriminator(\n",
    "        path_dim    = data_size,\n",
    "        adversarial = adversarial,\n",
    "        **truncated_mmd_config\n",
    "    ).to(device)\n",
    "    \n",
    "elif discriminator_type.lower() == \"wasserstein_cde\":\n",
    "    discriminator = CDEDiscriminator(\n",
    "        data_size = data_size,\n",
    "        **wasserstein_cde_config\n",
    "    ).to(device)\n",
    "else:\n",
    "    raise ValueError(f\"Unknown discriminator_type: {discriminator_type}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A note about the sigma scaling parameter, and the distribution of the MMD.\n",
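    "\n",
    "It is helpful to calibrate the starting value of sigma, for the same reasons given in the note on normalisation and scaling above. The cell below estimates the distribution of the MMD between independent batches of real data, i.e. under the null hypothesis that both batches come from the same distribution, and prints its empirical <code>mmd_ci</code> quantile. This gives a reference scale for the loss values seen during training."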
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GENERATE_NULL_MMD = False\n",
    "\n",
    "if GENERATE_NULL_MMD and (discriminator_type != \"wasserstein_cde\"):\n",
    "    with torch.no_grad():\n",
    "        n_atoms   = batch_size  # Number of null samples (kept separate from the mmd_atoms config)\n",
    "        true_mmds = np.zeros(n_atoms)\n",
    "\n",
    "        for i in tqdm(range(n_atoms)):\n",
    "            x, = next(infinite_train_dataloader)\n",
    "            y, = next(infinite_train_dataloader)\n",
    "\n",
    "            if subtract_start:\n",
    "                x = subtract_initial_point(x)\n",
    "                y = subtract_initial_point(y)\n",
    "\n",
    "            x = transformer(x)\n",
    "            y = transformer(y)\n",
    "\n",
    "            true_mmds[i] = float(discriminator(x, y.detach()))  # works for CPU and CUDA scalars\n",
    "\n",
    "        ci = sorted(true_mmds)[int(mmd_ci*n_atoms)]\n",
    "\n",
    "        fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
    "        plt.hist(sorted(true_mmds), bins=int(n_atoms/10), alpha=0.6, color=\"dodgerblue\", density=True)\n",
    "        make_grid()\n",
    "        print(f\"{100*mmd_ci:.0f}% CI: {ci:.5e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Optimisers and learning rate annealers\n",
    "\n",
    "Here we define optimisers for the generator (or the discriminator), and some learning rate annealers."
   ]
  },
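  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two of the scheduler options can be made concrete with a short pure-Python sketch of the learning rate they produce. PyTorch's `LambdaLR` sets the learning rate to the initial learning rate multiplied by `lr_lambda(epoch)`. `StepLR` multiplies it by `gamma` once every `step_size` scheduler steps. Here the schedulers are stepped once per training step. The multiplier `lambda_lr_factor` below has the same form as `f_lmd` in the next cell. The function names and numerical values are hypothetical and were chosen only for illustration.\n",
    "\n",
    "```python\n",
    "def lambda_lr_factor(epoch, lambda_lr_const, poly_exponent_smoother):\n",
    "    # Same form as f_lmd: equals 1 at epoch 0 and, for a negative exponent,\n",
    "    # decays towards lambda_lr_const as epoch grows.\n",
    "    return (1 - lambda_lr_const) * (epoch + 1.0) ** poly_exponent_smoother + lambda_lr_const\n",
    "\n",
    "def step_lr(epoch, base_lr, gamma, step_size):\n",
    "    # StepLR schedule: multiply by gamma once every step_size scheduler steps.\n",
    "    return base_lr * gamma ** (epoch // step_size)\n",
    "\n",
    "base_lr = 1e-3  # hypothetical initial learning rate\n",
    "lambda_schedule = [base_lr * lambda_lr_factor(e, 0.1, -0.5) for e in range(6)]\n",
    "step_schedule = [step_lr(e, base_lr, 0.5, 2) for e in range(6)]\n",
    "```\n",
    "\n",
    "With these values the `LambdaLR` rate starts at `base_lr` and decays towards `0.1 * base_lr`, while the `StepLR` rate halves every two steps. The dashed `lr_change` lines in the training plot are drawn at the corresponding change points."
   ]
  },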
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "averaged_generator     = swa_utils.AveragedModel(generator)\n",
    "averaged_discriminator = swa_utils.AveragedModel(discriminator)\n",
    "\n",
    "generator_optimiser     = getattr(torch.optim, gen_optim)(\n",
    "    generator.parameters(), lr=generator_lr, weight_decay=weight_decay\n",
    ")\n",
    "\n",
    "generator_optimiser.zero_grad()\n",
    "\n",
    "if adversarial or (discriminator_type == \"wasserstein_cde\"):\n",
    "    discriminator_optimiser_ = getattr(torch.optim, disc_optim)\n",
    "    if discriminator_type == \"weighted_sigker\":\n",
    "        discriminator_optimiser = discriminator_optimiser_([\n",
    "            {'params': discriminator._sigma},\n",
    "            {'params': discriminator._weights, \"lr\": discriminator_lr*5e3}\n",
    "        ], lr=discriminator_lr, weight_decay=weight_decay)\n",
    "        \n",
    "    discriminator_optimiser = discriminator_optimiser_(\n",
    "        discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay\n",
    "    )\n",
    "    \n",
    "    discriminator_optimiser.zero_grad()\n",
    "else:\n",
    "    discriminator_optimiser = None\n",
    "        \n",
    "if adapting_lr:\n",
    "    if adapting_lr_type == \"LambdaLR\":\n",
    "        f_lmd = lambda epoch: (1-lambda_lr_const)*np.power(epoch + 1., poly_exponent_smoother) + lambda_lr_const\n",
    "        adpt_kwargs = {\"lr_lambda\": f_lmd}\n",
    "    elif adapting_lr_type == \"StepLR\":\n",
    "        adpt_kwargs = {\"gamma\": gamma_lr, \"step_size\":steps_lr}\n",
    "    elif adapting_lr_type == \"MultiplicativeLR\":\n",
    "        adpt_kwargs = {\"lr_lambda\": lambda epoch: mult_const}\n",
    "        \n",
    "    elif adapting_lr_type == \"OneCycleLR\":\n",
    "        adpt_kwargs = {\n",
    "            \"max_lr\": max_lr, \n",
    "            \"total_steps\": total_steps, \n",
    "            \"pct_start\": pct_start, \n",
    "            \"anneal_strategy\": anneal_strategy,\n",
    "            \"div_factor\": div_factor\n",
    "        }\n",
    "    \n",
    "    g_scheduler, d_scheduler = get_scheduler(\n",
    "        generator_optimiser, \n",
    "        discriminator_optimiser, \n",
    "        adapting_lr_type, \n",
    "        adversarial,\n",
    "        **adpt_kwargs\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Training and evaluation\n",
    "\n",
    "We are now ready to proceed with training the GAN/moment-matching network, and evaluate its performance. \n",
    "\n",
    "The generator/discriminator pairs are saved when training is complete (and tagged appropriately) for future use."
   ]
  },
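  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The update scheme used in the training loop can be checked on a toy problem. It performs a single backward pass, then negates the discriminator's gradients before its optimiser step. The sketch below is a hypothetical pure-Python example and is not part of the notebook's pipeline. It applies simultaneous gradient descent-ascent to the convex-concave objective $f(\\theta, \\phi) = \\tfrac{1}{2}(\\theta - 2)^2 + \\theta\\phi - \\tfrac{1}{2}\\phi^2$, whose saddle point is $(\\theta, \\phi) = (1, 1)$.\n",
    "\n",
    "```python\n",
    "def grads(theta, phi):\n",
    "    # Partial derivatives of f(theta, phi) = 0.5*(theta - 2)**2 + theta*phi - 0.5*phi**2.\n",
    "    return theta - 2 + phi, theta - phi\n",
    "\n",
    "theta, phi, lr = 0.0, 0.0, 0.1\n",
    "for _ in range(200):\n",
    "    g_theta, g_phi = grads(theta, phi)  # both gradients from one evaluation, like loss.backward()\n",
    "    g_phi = -g_phi                      # analogue of param.grad *= -1 for the discriminator\n",
    "    phi = phi - lr * g_phi              # descent on -f, i.e. ascent on f\n",
    "    theta = theta - lr * g_theta        # ordinary descent for the generator\n",
    "\n",
    "print(round(theta, 6), round(phi, 6))  # 1.0 1.0\n",
    "```\n",
    "\n",
    "Both updates use gradients evaluated at the same point. This matches the notebook, where `loss.backward()` is called once and the discriminator and generator optimisers then step on those stored gradients."
   ]
  },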
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "TRAIN_MODEL      = True\n",
    "update_plots     = True  # See how training is going in real time.\n",
    "update_freq      = steps_per_print\n",
    "num_plot_samples = 64\n",
    "log_scale        = False  # Whether to plot the log-loss (WARNING: if negative, plot will look odd.)\n",
    "plot_times       = ts.cpu() if (end_time is None) or (data_type == \"forex\") else torch.linspace(0, end_time, path_length)\n",
    "\n",
    "sigmas           = torch.zeros((steps, data_size))\n",
    "#weights          = torch.zeros((steps, len(_weights)))\n",
    "\n",
    "do_total_loss    = False  # Compute total loss on dataset when printing progressive results.\n",
    "ttl_loss_flag    = bool(discriminator_type == \"wasserstein_cde\")\n",
    "\n",
    "step_vec     = np.arange(steps)\n",
    "av_loss      = torch.zeros(steps, requires_grad=False, dtype=torch.float32)\n",
    "burn_counter = 0 \n",
    "scaling_chg = []\n",
    "\n",
    "gen_fp = get_project_root().as_posix() + f\"/saved_models/generators/{data_type}_{path_length}_{batch_size}_{output_dim}_{discriminator_type}\"\n",
    "disc_fp = get_project_root().as_posix() + f\"/saved_models/discriminators/{data_type}_{path_length}_{batch_size}_{output_dim}_{discriminator_type}\"\n",
    "\n",
    "if TRAIN_MODEL:\n",
    "    tr_loss = torch.zeros(steps, requires_grad=False).to(device)\n",
    "    criterions = []\n",
    "\n",
    "    trange = tqdm(range(steps), position=0, leave=True)\n",
    "\n",
    "    for step in trange:\n",
    "\n",
    "        ###############################################################################\n",
    "        ## 1. Calculate loss\n",
    "        ###############################################################################  \n",
    "        \n",
    "        real_samples,     = next(infinite_train_dataloader)\n",
    "        real_samples      = transformer(real_samples).float()\n",
    "\n",
    "        generated_samples = generator(ts, batch_size)\n",
    "        generated_samples = transformer(generated_samples)\n",
    "\n",
    "        if subtract_start:\n",
    "            real_samples      = subtract_initial_point(real_samples)\n",
    "            generated_samples = subtract_initial_point(generated_samples)\n",
    "\n",
    "        if discriminator_type == \"wasserstein_cde\":\n",
    "            gen_score  = discriminator(generated_samples)\n",
    "            real_score = discriminator(real_samples)\n",
    "            loss       = gen_score - real_score\n",
    "        else:\n",
    "            loss = discriminator(generated_samples, real_samples.detach())\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        tr_loss[step] += loss.detach().clone()\n",
    "        burn_counter  += 1\n",
    "        \n",
    "        # Stepped sig_ker discriminator: changing the scaling factor\n",
    "        if False:\n",
    "            with torch.no_grad():\n",
    "                if step < _change_window:\n",
    "                    av_loss[step] = tr_loss[step].item()\n",
    "                    st_ind = step\n",
    "                else:\n",
    "                    st_ind = int(step-_change_window)\n",
    "                    av_loss[step] = torch.mean(tr_loss[st_ind:step]).item()\n",
    "\n",
    "                # Changing the scaling factor\n",
    "                if (discriminator_type == \"stepped_sigker\") and (burn_counter >= _burn_in):\n",
    "                    # Check the exit condition\n",
    "                    if (av_loss[step]/av_loss[st_ind] - 1 > _change_threshold) and (step >= _burn_in): \n",
    "                        curr_scaling = discriminator._scaling\n",
    "                        discriminator.set_scaling(curr_scaling**2 + _change_factor)\n",
    "                        scaling_chg.append(step)\n",
    "                        burn_counter = 0\n",
    "\n",
    "        ###############################################################################\n",
    "        ## 2. Plotting temporal results\n",
    "        ###############################################################################\n",
    "        \n",
    "        if (update_plots) and ((step % update_freq == 0) or (step == steps-1)):\n",
    "            with torch.no_grad():\n",
    "                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))\n",
    "\n",
    "                # Plot updated loss\n",
    "                #with torch.no_grad():\n",
    "                np_tr_loss = tr_loss.cpu().numpy()\n",
    "\n",
    "                current_loss  = np_tr_loss[:step]\n",
    "                future_loss   = 0. if current_loss.size == 0 else current_loss.min() - np.std(current_loss)\n",
    "                current_steps = step_vec[:step]\n",
    "                future_steps  = step_vec[step:]\n",
    "                future_loss   = np.array([future_loss for _ in range(future_steps.shape[0])])\n",
    "\n",
    "                ax1.plot(current_steps, current_loss, alpha=1., color=\"dodgerblue\", label=\"training_loss\")\n",
    "                ax1.plot(future_steps, future_loss, alpha=0.)\n",
    "\n",
    "                if adapting_lr and (adapting_lr_type == \"StepLR\"):\n",
    "                    lr_changes = np.array([steps_lr*(i+1) for i in range(int(step/steps_lr))])\n",
    "                    lr_plot = True\n",
    "                    for lr_c in lr_changes:\n",
    "                        kwargs = {\"label\": \"lr_change\"} if lr_plot else {}\n",
    "                        ax1.axvline(lr_c, linestyle=\"dashed\", alpha=0.5, color=\"grey\", **kwargs)\n",
    "                        lr_plot = False\n",
    "                        \n",
    "                if (discriminator_type == \"stepped_sigker\"):\n",
    "                    sc_label = True\n",
    "                    for sc in scaling_chg:\n",
    "                        kwargs = {\"label\": \"scaling_change\"} if sc_label else {}\n",
    "                        ax1.axvline(sc, linestyle=\"dashed\", alpha=0.5, color=\"grey\", **kwargs)\n",
    "                        sc_label = False\n",
    "\n",
    "                if log_scale:\n",
    "                    ax1.set_yscale(\"log\")\n",
    "                    \n",
    "                make_grid(axis=ax1)\n",
    "                ax1.legend()\n",
    "\n",
    "                if do_transforms:\n",
    "                    real_plot_samples          = transformer.backward(real_samples).cpu()\n",
    "                    generated_plot_samples = transformer.backward(generated_samples).cpu()\n",
    "                else:\n",
    "                    real_plot_samples = real_samples.cpu()    \n",
    "                    generated_plot_samples = generated_plot_samples.cpu()\n",
    "\n",
    "                if subtract_start:\n",
    "                    real_plot_samples      = subtract_initial_point(real_plot_samples)\n",
    "                    generated_plot_samples = subtract_initial_point(generated_plot_samples)\n",
    "\n",
    "                this_dim = np.random.randint(1, data_size + 1)\n",
    "\n",
    "                real_times             = real_plot_samples[..., 0]\n",
    "                real_plot_times        = real_times[:num_plot_samples]\n",
    "                real_plot_samples      = real_plot_samples[:num_plot_samples, :, this_dim]\n",
    "                generated_plot_samples = generated_plot_samples[:num_plot_samples, :, this_dim]\n",
    "                real_first      = True\n",
    "                generated_first = True\n",
    "                for i, real_sample_ in enumerate(real_plot_samples):\n",
    "                    kwargs = {'label': 'Real'} if real_first else {}\n",
    "                    ax2.plot(plot_times, real_sample_.cpu(), color='dodgerblue', linewidth=0.5, alpha=0.5, **kwargs)\n",
    "                    real_first = False\n",
    "                for generated_sample_ in generated_plot_samples:\n",
    "                    kwargs = {'label': 'Generated'} if generated_first else {}\n",
    "                    ax2.plot(plot_times, generated_sample_.cpu(), color='crimson', linewidth=0.5, alpha=0.5, **kwargs)\n",
    "                    generated_first = False\n",
    "                ax2.legend()\n",
    "                ax2.grid(visible=True, color='grey', linestyle=':', linewidth=1.0, alpha=0.3)\n",
    "                ax2.minorticks_on()\n",
    "                ax2.grid(visible=True, which='minor', color='grey', linestyle=':', linewidth=1.0, alpha=0.1)\n",
    "\n",
    "                cont_title = f\"Epoch {step}, {data_type} paths\"\n",
    "                extr_title = f\", {forex_pairs[int(this_dim-1)]}\" if data_type == \"forex\" else \"\"\n",
    "\n",
    "                fig.suptitle(cont_title + extr_title)\n",
    "\n",
    "                display.clear_output(wait=True)\n",
    "                display.display(plt.gcf())\n",
    "            \n",
    "        ###############################################################################\n",
    "        ## 3. Step through optimisers and adapting LR schedulers, stochastic weights\n",
    "        ###############################################################################\n",
    "        if adversarial or (discriminator_type == \"wasserstein_cde\"):\n",
    "            for param in discriminator.parameters():\n",
    "                param.grad *= -1\n",
    "            \n",
    "            discriminator_optimiser.step()\n",
    "            discriminator_optimiser.zero_grad()\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                if discriminator_type in [\"sigker_mmd\", \"truncated_mmd\"]:\n",
    "                    \n",
    "                    if clip_disc_param:\n",
    "                        for param in discriminator.parameters():\n",
    "                            param.clamp_(1, 1e2)\n",
    "                            \n",
    "                    sigmas[step] = discriminator._sigma\n",
    "                    \n",
    "                elif discriminator_type == \"wasserstein_cde\":\n",
    "                    for module in discriminator.modules():\n",
    "                        if isinstance(module, torch.nn.Linear):\n",
    "                            lim = 1 / module.out_features\n",
    "                            module.weight.clamp_(-lim, lim)\n",
    "                elif discriminator_type == \"weighted_sigker\":\n",
    "                    sigmas[step] = discriminator._sigma\n",
    "                    weights[step] = discriminator._weights\n",
    "                    \n",
    "                    for name, param in discriminator.named_parameters():\n",
    "                        if name == \"_weights\":\n",
    "                            param.clamp(0, 1)\n",
    "                            param /= torch.sum(param)\n",
    "            \n",
    "            if adapting_lr:\n",
    "                d_scheduler.step()\n",
    "            \n",
    "        generator_optimiser.step()\n",
    "        generator_optimiser.zero_grad()\n",
    "        \n",
    "        if adapting_lr:\n",
    "            g_scheduler.step()\n",
    "\n",
    "        # Stochastic weight averaging of generator (and discriminator, doesn't matter when not adversarial)\n",
    "        if step > swa_step_start:\n",
    "            averaged_generator.update_parameters(generator)\n",
    "            averaged_discriminator.update_parameters(discriminator)\n",
    "            \n",
    "        ###############################################################################\n",
    "        ## 4. Early stopping criterions\n",
    "        ###############################################################################\n",
    "        if (step % steps_per_print) == 0 or step == steps - 1:\n",
    "            with torch.no_grad():\n",
    "                # Check if early exit\n",
    "                if early_stopping_type == \"marginals\":\n",
    "                    criterion = 0\n",
    "\n",
    "                    for _ in range(crit_evals):\n",
    "                        crit_generated_samples  = generator(ts, batch_size)\n",
    "                        criterion_samples,      = next(infinite_train_dataloader)\n",
    "\n",
    "                        if subtract_start:\n",
    "                            crit_generated_samples = subtract_initial_point(crit_generated_samples)\n",
    "                            criterion_samples      = subtract_initial_point(criterion_samples)\n",
    "\n",
    "                        criterion += stopping_criterion(criterion_samples, crit_generated_samples, cutoff=cutoff, print_results=False)\n",
    "                    av_criterion = criterion/crit_evals\n",
    "                    criterions.append(av_criterion)\n",
    "\n",
    "                    if av_criterion > crit_thresh:\n",
    "                        trange.write(\"Stopping criterion reached. Exiting training early.\")\n",
    "                        tr_loss[step:] = tr_loss[step]\n",
    "                        break\n",
    "                    crit_text = f\"Criterion: {av_criterion:.4f} Target: {crit_thresh:.4f}\"\n",
    "                elif (early_stopping_type == \"mmd\") and step > mmd_periods:\n",
    "                    averaged_mmd_score = torch.mean(tr_loss[step-mmd_periods:step])\n",
    "                    if averaged_mmd_score <= ci:\n",
    "                        trange.write(\"Stopping criterion reached. Exiting training early.\")\n",
    "                        tr_loss[step:] = tr_loss[step]\n",
    "                        break\n",
    "                    crit_text = f\"Criterion: {averaged_mmd_score:.5e} Target {ci:.5e}\"\n",
    "                else:\n",
    "                    crit_text = \"\"\n",
    "\n",
    "                # Print total loss on dataset\n",
    "                if do_total_loss:\n",
    "                    total_unaveraged_loss = evaluate_loss(\n",
    "                        ts, batch_size, train_dataloader, generator, discriminator, transformer, subtract_start, cde_disc=ttl_loss_flag\n",
    "                    )\n",
    "                else:\n",
    "                    total_unaveraged_loss = loss.item()\n",
    "\n",
    "                if step > swa_step_start:\n",
    "                    if do_total_loss:\n",
    "                        total_averaged_loss = evaluate_loss(ts, batch_size, train_dataloader, averaged_generator.module,\n",
    "                                                            averaged_discriminator, transformer, subtract_start, cde_disc=ttl_loss_flag)\n",
    "                    else:\n",
    "                        total_averaged_loss = total_unaveraged_loss\n",
    "                    trange.write(f\"Step: {step:3} Total loss (unaveraged): {total_unaveraged_loss:.5e} \"\n",
    "                                 f\"Loss (averaged): {total_averaged_loss:.5e} \" + crit_text)\n",
    "                else:\n",
    "                    trange.write(f\"Step: {step:3} Total loss (unaveraged): {total_unaveraged_loss:.5e} \" + crit_text)\n",
    "                \n",
    "    ###############################################################################\n",
    "    ## 5. Training complete\n",
    "    ################################b###############################################\n",
    "    torch.save(generator.state_dict(), gen_fp + \".pkl\")\n",
    "    torch.save(discriminator.state_dict(), disc_fp + \".pkl\")\n",
    "    torch.save(generator_config, gen_fp + \"_config.pkl\")\n",
    "    torch.save(eval(discriminator_type + \"_config\"), disc_fp + \"_config.pkl\")\n",
    "    \n",
    "    torch.save(averaged_generator.state_dict(), gen_fp + \"_averaged.pkl\")\n",
    "    torch.save(averaged_discriminator.state_dict(), disc_fp + \"_averaged.pkl\")\n",
    "        \n",
    "    plot_loss(tr_loss)\n",
    "else:\n",
    "    try:\n",
    "        generator_state_dict     = torch.load(gen_fp  + \".pkl\")\n",
    "        discriminator_state_dict = torch.load(disc_fp + \".pkl\")\n",
    "        generator_config         = torch.load(gen_fp  + \"_config.pkl\")\n",
    "\n",
    "        averaged_generator_state_dict     = torch.load(gen_fp + \"_averaged.pkl\")\n",
    "        averaged_discriminator_state_dict = torch.load(disc_fp + \"_averaged.pkl\")\n",
    "    except FileNotFoundError as e:\n",
    "        print(\"Model needs to be trained first. Please set TRAIN_MODEL to True.\")\n",
    "    \n",
    "    generator.load_state_dict(generator_state_dict)\n",
    "    discriminator.load_state_dict(discriminator_state_dict)\n",
    "    \n",
    "    averaged_generator.load_state_dict(averaged_generator_state_dict)\n",
    "    averaged_discriminator.load_state_dict(averaged_discriminator_state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Validation\n",
    "\n",
    "We present some methods of validation for the provided generator."
   ]
  },
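  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With its default averaging function, `torch.optim.swa_utils.AveragedModel` keeps an equal-weight running mean of every parameter snapshot passed to `update_parameters`. The first call simply copies the parameters. A minimal pure-Python sketch of that update rule follows. The function name `running_average` and the snapshot values are hypothetical, and this is not the library's implementation.\n",
    "\n",
    "```python\n",
    "def running_average(snapshots):\n",
    "    # Equal-weight running mean, updated one snapshot at a time:\n",
    "    # avg_{n+1} = avg_n + (p_{n+1} - avg_n) / (n + 1), with the first snapshot copied.\n",
    "    avg, n = None, 0\n",
    "    for p in snapshots:\n",
    "        if avg is None:\n",
    "            avg = list(p)\n",
    "        else:\n",
    "            avg = [a + (pi - a) / (n + 1) for a, pi in zip(avg, p)]\n",
    "        n += 1\n",
    "    return avg\n",
    "\n",
    "# Hypothetical flattened parameter snapshots recorded after swa_step_start.\n",
    "snapshots = [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]\n",
    "print(running_average(snapshots))  # [2.0, 5.0]\n",
    "```\n",
    "\n",
    "Every snapshot taken after `swa_step_start` receives the same weight. As a result, the averaged generator smooths out step-to-step fluctuations in the parameters late in training."
   ]
  },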
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_samples,     = next(iter(test_dataloader))\n",
    "generated_samples = averaged_generator(ts, batch_size)\n",
    "\n",
    "if subtract_start:\n",
    "    real_samples      = subtract_initial_point(real_samples)\n",
    "    generated_samples = subtract_initial_point(generated_samples)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))\n",
    "\n",
    "with torch.no_grad():\n",
    "    np_tr_loss = tr_loss.cpu().numpy()\n",
    "\n",
    "ax1.plot(step_vec[:step], np_tr_loss[:step], alpha=1., color=\"dodgerblue\", label=\"training_loss\")\n",
    "\n",
    "if adversarial and (discriminator_type != \"wasserstein_cde\"):\n",
    "    ax12 = ax1.twinx()\n",
    "    for ii, sigma in enumerate(sigmas.T):\n",
    "        ax12.plot(sigma[:step], label=f\"discriminator_sigmas_{ii+1}\", alpha=0.75, color=\"grey\")\n",
    "    if discriminator_type == \"weighted_sigker\":\n",
    "        for jj, weight in enumerate(weights.T):\n",
    "            ax1.plot(weight[:step], label=f\"discriminator_weights_{jj+1}\", alpha=0.75)\n",
    "            \n",
    "make_grid(axis=ax1)\n",
    "extra_title = \" against discriminator sigma\" if adversarial else \"\"\n",
    "ax1.set_title(\"Training loss\" + extra_title)\n",
    "ax1.legend()\n",
    "\n",
    "if discriminator_type == \"stepped_sigker\":\n",
    "    bi_index = np.arange(steps)[_change_window + _burn_in:]\n",
    "    pct_av_loss_change = av_loss[_burn_in + _change_window:]/av_loss[_burn_in:-_change_window] - 1\n",
    "    ax12.plot(bi_index, pct_av_loss_change, alpha=0.5, color=\"grey\")\n",
    "    _ch_flag = True\n",
    "    for chgpt in scaling_chg:\n",
    "        ax12.axvline(chgpt, color=\"grey\", alpha=0.05, linestyle=\"dashed\")\n",
    "\n",
    "real_plot_samples = real_samples[..., 1:]\n",
    "with torch.no_grad():\n",
    "    generated_plot_samples = generated_samples.cpu()[..., 1:]\n",
    "    \n",
    "real_plot_samples      = real_plot_samples[:num_plot_samples]\n",
    "generated_plot_samples = generated_plot_samples[:num_plot_samples]\n",
    "real_first = True\n",
    "generated_first = True\n",
    "for real_sample_ in real_plot_samples:\n",
    "    kwargs = {'label': 'Real'} if real_first else {}\n",
    "    ax2.plot(ts.cpu(), real_sample_.cpu(), color='dodgerblue', linewidth=0.5, alpha=0.5, **kwargs)\n",
    "    real_first = False\n",
    "for generated_sample_ in generated_plot_samples:\n",
    "    kwargs = {'label': 'Generated'} if generated_first else {}\n",
    "    ax2.plot(ts.cpu(), generated_sample_.cpu(), color='crimson', linewidth=0.5, alpha=0.5, **kwargs)\n",
    "    generated_first = False\n",
    "ax2.legend()\n",
    "make_grid(axis=ax2)\n",
    "ax2.set_title(\"Real vs generated paths\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "num_plot_samples_ = min(batch_size, num_plot_samples)\n",
    "\n",
    "plot_results(ts, averaged_generator, test_dataloader, num_plot_samples_, plot_locs, subtract_start=subtract_start, figsize=(6, 18))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more specific validation methods, see the notebook <code>unconditional_nsde_validation</code>."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
