{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conditional Neural Processes (CNP) for 1D regression.\n",
    "[Conditional Neural Processes](https://arxiv.org/pdf/1807.01613.pdf) (CNPs) were\n",
    "introduced as a continuation of\n",
    "[Generative Query Networks](https://deepmind.com/blog/neural-scene-representation-and-rendering/)\n",
    "(GQN) to extend its training regime to tasks beyond scene rendering, e.g. to\n",
    "regression and classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch #torch==2.1.2\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import datetime\n",
    "import numpy as np #numpy==1.24.3\n",
    "import torchsnooper #torchsnooper==0.8\n",
    "import os\n",
    "import plotting_utils_cnp as plotting\n",
    "import data_generator as data\n",
    "from matplotlib.backends.backend_pdf import PdfPages\n",
    "import pandas as pd #pandas==2.0.1\n",
    "import dask.dataframe as dd\n",
    "import import_ipynb\n",
    "import conditional_neural_process_model as cnp\n",
    "import pickle as pkl\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../utilities/concept.png\" alt=\"drawing\" width=\"500\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Conditional Neural Processes\n",
    "\n",
    "Now that the data generator and the model are available (imported above as\n",
    "`data_generator` and `conditional_neural_process_model`), we can put everything\n",
    "together. Before we get started we need to set some variables:\n",
    "\n",
    "*   **`TRAINING_ITERATIONS`** - a scalar that describes the number of iterations\n",
    "    for training. At each iteration we will sample a new batch of functions from\n",
    "    the GP, pick some of the points on the curves as our context points **(x,\n",
    "    y)<sub>C</sub>** and some points as our target points **(x,\n",
    "    y)<sub>T</sub>**. We will predict the mean and variance at the target points\n",
    "    given the context and use the log likelihood of the ground truth targets as\n",
    "    our loss to update the model.\n",
    "*   **`MAX_CONTEXT_POINTS`** - a scalar that sets the maximum number of context\n",
    "    points used during training. The number of context points will then be a\n",
    "    value between 3 and `MAX_CONTEXT_POINTS` that is sampled at random for every\n",
    "    iteration.\n",
    "*   **`PLOT_AFTER`** - a scalar that regulates how often we plot the\n",
    "    intermediate results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: training schedule, input columns, versions and output locations.\n",
    "\n",
    "# Total number of training points: TRAINING_ITERATIONS * BATCH_SIZE * MAX_CONTEXT_POINTS\n",
    "TRAINING_ITERATIONS = int(3540)\n",
    "#BATCH_SIZE = 100 # number of simulation configurations\n",
    "\n",
    "MAX_CONTEXT_POINTS = 1000 # 2000 # 4000\n",
    "MAX_TARGET_POINTS =  2000 # 4000 # 8000\n",
    "CONTEXT_IS_SUBSET = True\n",
    "BATCH_SIZE = 1\n",
    "CONFIG_WISE = False\n",
    "PLOT_AFTER = int(500)\n",
    "# Fix the torch RNG so repeated runs start from the same state\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# all available x config/ physics parameters are [\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\",\"height\",\"z_offset\",\"volume\",\"nC_Ge77\",\"time_0[ms]\",\"x_0[m]\",\"y_0[m]\",\"z_0[m]\",\"px_0[m]\",\"py_0[m]\",\"pz_0[m]\",\"ekin_0[eV]\",\"edep_0[eV]\",\"time_t[ms]\",\"x_t[m]\",\"y_t[m]\",\"z_t[m]\",\"px_t[m]\",\"py_t[m]\",\"pz_t[m]\",\"ekin_t[eV]\",\"edep_t[eV]\",\"nsec\"]\n",
    "# Comment: if using data version v1.1 for training, \"radius\",\"thickness\",\"npanels\",\"theta\",\"length\" is probably necessary\n",
    "names_x=[\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\",\"r_0[m]\",\"z_0[m]\",\"time_t[ms]\",\"r_t[m]\",\"z_t[m]\",\"L_t[m]\",\"ln(E0vsET)\",\"edep_t[eV]\",\"nsec\"]\n",
    "name_y ='total_nC_Ge77[cts]'\n",
    "x_size = len(names_x)\n",
    "# name_y is either a single column name (one target) or a sequence of column names\n",
    "y_size = 1 if isinstance(name_y, str) else len(name_y)\n",
    "\n",
    "RATIO_TESTING_VS_TRAINING = 1/40\n",
    "version_cnp=\"v1.6\"\n",
    "version_lf=\"v1.4\"\n",
    "\n",
    "path_to_files=f\"../simulation/out/LF/{version_lf}/tier2/\"\n",
    "path_out = './out/'\n",
    "# f_out is a bare file stem: the cells that write files prepend the directory themselves\n",
    "# (f'{path_out}{f_out}...' and f'./out/{f_out}...'), so it must not contain path_out,\n",
    "# otherwise the outputs would be written to './out/./out/...'.\n",
    "f_out = f'CNPGauss_{version_cnp}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data augmentation methods used:\n",
    "\n",
    "<img src=\"../utilities/data_augmentation.png\" alt=\"drawing\" width=\"800\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set data augmentation parameters\n",
    "USE_DATA_AUGMENTATION = \"mixup\" #\"smote\" #False #\"mixup\"\n",
    "# USE_BETA: uniform => None, beta => [a,b]; U-shape [0.1,0.1], uniform [1.,1.], falling [0.2,0.5], rising [0.5,0.2]\n",
    "USE_BETA = [0.1,0.1]\n",
    "SIGNAL_TO_BACKGROUND_RATIO = \"\" # \"_1to4\" # used for smote augmentation\n",
    "\n",
    "if USE_DATA_AUGMENTATION:\n",
    "    # Augmented runs write into their own sub-directory and get a tagged file stem.\n",
    "    run_stem = f'CNPGauss_{version_cnp}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}'\n",
    "    path_out = f'./out/{USE_DATA_AUGMENTATION}/'\n",
    "    f_out = f'{run_stem}_{USE_DATA_AUGMENTATION}{SIGNAL_TO_BACKGROUND_RATIO}'\n",
    "    if USE_DATA_AUGMENTATION == \"mixup\":\n",
    "        if USE_BETA is None:\n",
    "            # The tier3 input directory and the file stem are both named after the two beta\n",
    "            # parameters, so there is nothing to point at when they are missing.\n",
    "            raise ValueError(\"mixup needs USE_BETA = [a, b]; no input directory is defined for USE_BETA = None\")\n",
    "        beta_tag = f'beta_{USE_BETA[0]}_{USE_BETA[1]}'\n",
    "        path_to_files = f\"../simulation/out/LF/{version_lf}/tier3/{beta_tag}/\"\n",
    "        f_out = f'{run_stem}_{beta_tag}'\n",
    "    elif USE_DATA_AUGMENTATION == \"smote\" and CONFIG_WISE:\n",
    "        path_to_files = f\"../simulation/out/LF/{version_lf}/tier3/smote{SIGNAL_TO_BACKGROUND_RATIO}/\"\n",
    "    # every other combination keeps the path_to_files chosen in the configuration cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments that are identical for the training and the testing generator\n",
    "shared_dataset_kwargs = dict(\n",
    "    num_context_points=MAX_CONTEXT_POINTS,\n",
    "    num_target_points=MAX_TARGET_POINTS,\n",
    "    x_size=x_size,\n",
    "    y_size=y_size,\n",
    "    ratio_testing=RATIO_TESTING_VS_TRAINING,\n",
    "    sig_bkg_ratio=SIGNAL_TO_BACKGROUND_RATIO,\n",
    "    names_x=names_x,\n",
    "    name_y=name_y,\n",
    ")\n",
    "\n",
    "# Train dataset (reads the possibly augmented files selected above)\n",
    "dataset_train = data.DataGeneration(\n",
    "    num_iterations=TRAINING_ITERATIONS,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    config_wise=CONFIG_WISE,\n",
    "    path_to_files=path_to_files,\n",
    "    mode=\"training\",\n",
    "    use_data_augmentation=USE_DATA_AUGMENTATION,\n",
    "    **shared_dataset_kwargs,\n",
    ")\n",
    "# The generator has the final say on how many iterations are available\n",
    "TRAINING_ITERATIONS = dataset_train._num_iterations\n",
    "\n",
    "# Testing dataset: always the un-augmented tier2 files, one configuration per batch\n",
    "dataset_testing = data.DataGeneration(\n",
    "    num_iterations=int(np.round(TRAINING_ITERATIONS/PLOT_AFTER))+5,\n",
    "    batch_size=1,\n",
    "    config_wise=False,\n",
    "    path_to_files=f\"../simulation/out/LF/{version_lf}/tier2/\",\n",
    "    mode=\"testing\",\n",
    "    use_data_augmentation=\"None\",\n",
    "    **shared_dataset_kwargs,\n",
    ")\n",
    "TRAINING_ITERATIONS = min(TRAINING_ITERATIONS, dataset_train._num_iterations)\n",
    "\n",
    "# Widen the plotting interval (rounded up to a multiple of 5) when the testing generator\n",
    "# offers fewer iterations than the requested interval would consume.\n",
    "min_plot_interval = np.ceil(TRAINING_ITERATIONS/(dataset_testing._num_iterations-2))\n",
    "if PLOT_AFTER < int(min_plot_interval):\n",
    "    PLOT_AFTER = int(5 * np.ceil(min_plot_interval/5))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now instantiate the model, define the optimizer and run the training\n",
    "loop, evaluating on the testing set at regular intervals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network layout: the encoder input width is x_size + y_size, the decoder input width is\n",
    "# representation_size + x_size, and the decoder emits y_size + 1 values per target point.\n",
    "d_x, d_in, representation_size, d_out = x_size, x_size + y_size, 32, y_size + 1\n",
    "encoder_sizes = [d_in, 32, 64, 128, 128, 128, 64, 48, representation_size]\n",
    "decoder_sizes = [representation_size + d_x, 32, 64, 128, 128, 128, 64, 48, d_out]\n",
    "\n",
    "model = cnp.DeterministicModel(encoder_sizes, decoder_sizes)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "#optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)\n",
    "\n",
    "bce = nn.BCELoss()\n",
    "\n",
    "# Logging / evaluation cadence. The absolute thresholds switch to per-iteration logging\n",
    "# and testing near the end of a run of about 3540 iterations.\n",
    "LOG_EVERY = 500\n",
    "LOG_ALL_AFTER = 3400\n",
    "TEST_ALL_AFTER = 3500\n",
    "# Figures are stored on every 5th plotting interval and on the last iteration.\n",
    "SAVE_FIG_EVERY = PLOT_AFTER * 5\n",
    "\n",
    "\n",
    "def bce_or_sentinel(prediction, target):\n",
    "    \"\"\"Return the binary cross entropy as a float, or -1.0 when it is not defined.\n",
    "\n",
    "    nn.BCELoss only accepts inputs in [0, 1], so the loss is evaluated only when every\n",
    "    predicted value lies in that range; otherwise -1.0 is returned as a marker for the log.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        if prediction.min() >= 0 and prediction.max() <= 1:\n",
    "            return bce(prediction, target).item()\n",
    "    return -1.\n",
    "\n",
    "\n",
    "# Output files; make sure their directories exist before anything is opened for writing\n",
    "training_log_path = f'{path_out}{f_out}_training.txt'\n",
    "training_pdf_path = f'{path_out}{f_out}_training.pdf'\n",
    "figure_pickle_path = f'./out/{f_out}_distr.p'\n",
    "model_path = f'./out/{f_out}_model.pth'\n",
    "for out_file in (training_log_path, training_pdf_path, figure_pickle_path, model_path):\n",
    "    out_dir = os.path.dirname(out_file)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "iter_testing = 0\n",
    "last_iteration = TRAINING_ITERATIONS - 1\n",
    "\n",
    "# The context managers close the log and finalise the PDF even if training is interrupted\n",
    "with open(training_log_path, 'w') as fout, PdfPages(training_pdf_path) as pdf:\n",
    "    for it in range(TRAINING_ITERATIONS):\n",
    "        # load data:\n",
    "        data_train = dataset_train.get_data(it, CONTEXT_IS_SUBSET)\n",
    "\n",
    "        # Get the log likelihood and the predicted mean at the target points of the training batch\n",
    "        log_prob, mu, _ = model(data_train.query, data_train.target_y)\n",
    "\n",
    "        # Negative log likelihood loss, followed by one gradient descent step\n",
    "        loss = -log_prob.mean()\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_value = loss.item()\n",
    "        loss_bce = bce_or_sentinel(mu, data_train.target_y)\n",
    "        mu = mu[0].detach().numpy()\n",
    "\n",
    "        if it % LOG_EVERY == 0 or it > LOG_ALL_AFTER:\n",
    "            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "            train_message = 'Iteration: {}/{}, train loss: {:.4f} (vs BCE {:.4f})'.format(it, TRAINING_ITERATIONS, loss_value, loss_bce)\n",
    "            print('{} {}'.format(timestamp, train_message))\n",
    "            fout.write(train_message + '\\n')\n",
    "\n",
    "        if not (it % PLOT_AFTER == 0 or it == last_iteration or it > TEST_ALL_AFTER):\n",
    "            continue\n",
    "\n",
    "        # Evaluate on the next testing batch; no gradients are needed here.\n",
    "        # NOTE(review): with per-iteration testing after TEST_ALL_AFTER, iter_testing can grow\n",
    "        # beyond dataset_testing._num_iterations - confirm that get_data copes with that.\n",
    "        data_testing = dataset_testing.get_data(iter_testing, CONTEXT_IS_SUBSET)\n",
    "        iter_testing += 1\n",
    "        with torch.no_grad():\n",
    "            log_prob_testing, mu_testing, _ = model(data_testing.query, data_testing.target_y)\n",
    "        loss_testing = -log_prob_testing.mean().item()\n",
    "        loss_bce_testing = bce_or_sentinel(mu_testing, data_testing.target_y)\n",
    "        mu_testing = mu_testing[0].detach().numpy()\n",
    "\n",
    "        timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "        test_message = '{}, Iteration: {}, test loss: {:.4f} (vs BCE {:.4f})'.format(timestamp, it, loss_testing, loss_bce_testing)\n",
    "        print(test_message)\n",
    "        fout.write(test_message + '\\n')\n",
    "\n",
    "        # One figure for a single target column, otherwise one figure per target column\n",
    "        y_train = data_train.target_y[0].detach().numpy()\n",
    "        y_testing = data_testing.target_y[0].detach().numpy()\n",
    "        if isinstance(name_y, str):\n",
    "            figs = [plotting.plot(mu, y_train, f'{loss_value:.2f}', mu_testing, y_testing, f'{loss_testing:.2f}', it)]\n",
    "        else:\n",
    "            figs = [\n",
    "                plotting.plot(mu[:, k], y_train[:, k], f'{loss_value:.2f}', mu_testing[:, k], y_testing[:, k], f'{loss_testing:.2f}', it)\n",
    "                for k in range(y_size)\n",
    "            ]\n",
    "\n",
    "        # Parenthesised on purpose: 'it % PLOT_AFTER*5' evaluates as '(it % PLOT_AFTER)*5'\n",
    "        if figs and (it % SAVE_FIG_EVERY == 0 or it == last_iteration):\n",
    "            for fig in figs:\n",
    "                pdf.savefig(fig)\n",
    "            # The pickle keeps the most recent figure only\n",
    "            with open(figure_pickle_path, 'wb') as pickle_file:\n",
    "                pkl.dump(figs[-1], pickle_file)\n",
    "            plt.show()\n",
    "        # Close every figure of this step so they do not pile up in memory\n",
    "        for fig in figs:\n",
    "            plt.close(fig)\n",
    "\n",
    "torch.save(model.state_dict(), model_path)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
