{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conditional Neural Processes (CNP) for 1D regression.\n",
    "[Conditional Neural Processes](https://arxiv.org/pdf/1807.01613.pdf) (CNPs) were\n",
    "introduced as a continuation of\n",
    "[Generative Query Networks](https://deepmind.com/blog/neural-scene-representation-and-rendering/)\n",
    "(GQN) to extend its training regime to tasks beyond scene rendering, e.g. to\n",
    "regression and classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch #torch==2.1.2\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import datetime\n",
    "import numpy as np #numpy==1.24.3\n",
    "import torchsnooper #torchsnooper==0.8\n",
    "import plotting_utils_cnp as plotting\n",
    "import data_generator as data\n",
    "from matplotlib.backends.backend_pdf import PdfPages\n",
    "import pandas as pd #pandas==2.0.1\n",
    "import dask.dataframe as dd\n",
    "import sys\n",
    "sys.path.append('../utilities')\n",
    "import utilities as utils\n",
    "import import_ipynb\n",
    "import conditional_neural_process_model as cnp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Conditional Neural Processes\n",
    "\n",
    "Now that we have defined the dataset as well as our model and its components we\n",
    "can start building everything into the graph. Before we get started we need to\n",
    "set some variables:\n",
    "\n",
    "*   **`TRAINING_ITERATIONS`** - a scalar that describes the number of iterations\n",
    "    for training. At each iteration we will sample a new batch of functions from\n",
    "    the GP, pick some of the points on the curves as our context points **(x,\n",
    "    y)<sub>C</sub>** and some points as our target points **(x,\n",
    "    y)<sub>T</sub>**. We will predict the mean and variance at the target points\n",
    "    given the context and use the log likelihood of the ground truth targets as\n",
    "    our loss to update the model.\n",
    "*   **`MAX_CONTEXT_POINTS`** - a scalar that sets the maximum number of contest\n",
    "    points used during training. The number of context points will then be a\n",
    "    value between 3 and `MAX_CONTEXT_POINTS` that is sampled at random for every\n",
    "    iteration.\n",
    "*   **`PLOT_AFTER`** - a scalar that regulates how often we plot the\n",
    "    intermediate results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAINING_ITERATIONS = int(3540) # Total number of training points: training_iterations * batch_size * max_content_points\n",
    "#BATCH_SIZE = 100 # number of simulation configurations\n",
    "\n",
    "MAX_CONTEXT_POINTS = 1000 # 2000 # 4000\n",
    "MAX_TARGET_POINTS =  2000 # 4000 # 8000\n",
    "CONTEXT_IS_SUBSET = True\n",
    "BATCH_SIZE = 1\n",
    "CONFIG_WISE = False\n",
    "PLOT_AFTER = int(200)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# all available x config/ physics parameters are [\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\",\"height\",\"z_offset\",\"volume\",\"nC_Ge77\",\"time_0[ms]\",\"x_0[m]\",\"y_0[m]\",\"z_0[m]\",\"px_0[m]\",\"py_0[m]\",\"pz_0[m]\",\"ekin_0[eV]\",\"edep_0[eV]\",\"time_t[ms]\",\"x_t[m]\",\"y_t[m]\",\"z_t[m]\",\"px_t[m]\",\"py_t[m]\",\"pz_t[m]\",\"ekin_t[eV]\",\"edep_t[eV]\",\"nsec\"]\n",
    "# Comment: if using data version v1.1 for training, \"radius\",\"thickness\",\"npanels\",\"theta\",\"length\" is probably necessary\n",
    "names_x=[\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\",\"r_0[m]\",\"z_0[m]\",\"time_t[ms]\",\"r_t[m]\",\"z_t[m]\",\"L_t[m]\",\"ln(E0vsET)\",\"edep_t[eV]\",\"nsec\"]\n",
    "name_y ='total_nC_Ge77[cts]'\n",
    "x_size = len(names_x)\n",
    "if isinstance(name_y,str):\n",
    "    y_size = 1\n",
    "else:\n",
    "    y_size = len(name_y)\n",
    "\n",
    "RATIO_TESTING_VS_TRAINING = 1/40\n",
    "version_cnp=\"v1.6\"\n",
    "version_lf=\"v1.4\"\n",
    "version_hf=\"v1.3\"\n",
    "path_to_files=f\"../simulation/out/LF/{version_lf}/tier2/\"\n",
    "path_out = f'./out/'\n",
    "f_out = f'{path_out}CNPGauss_{version_cnp}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set data augmentation parameters\n",
    "USE_DATA_AUGMENTATION = \"mixup\" #\"smote\" #False #\"mixup\"\n",
    "USE_BETA = [0.1,0.1] # uniform => None, beta => [a,b] U-shape [0.1,0.1] Uniform [1.,1.] falling [0.2,0.5] rising [0.2,0.5]\n",
    "SIGNAL_TO_BACKGROUND_RATIO = \"\" # \"_1to4\" # used for smote augmentation\n",
    "\n",
    "if USE_DATA_AUGMENTATION:\n",
    "    path_out = f'./out/{USE_DATA_AUGMENTATION}/'\n",
    "    f_out = f'CNPGauss_{version_cnp}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}_{USE_DATA_AUGMENTATION}{SIGNAL_TO_BACKGROUND_RATIO}'\n",
    "    if USE_DATA_AUGMENTATION == \"mixup\":\n",
    "        path_to_files = f\"../simulation/out/LF/{version_lf}/tier3/beta_{USE_BETA[0]}_{USE_BETA[1]}/\"\n",
    "        f_out = f'CNPGauss_{version_cnp}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}_beta_{USE_BETA[0]}_{USE_BETA[1]}'\n",
    "    elif USE_DATA_AUGMENTATION == \"smote\" and CONFIG_WISE == True:\n",
    "        path_to_files = f\"../simulation/out/LF/{version_lf}/tier3/smote{SIGNAL_TO_BACKGROUND_RATIO}/\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d_x, d_in, representation_size, d_out = x_size , x_size+y_size, 32, y_size+1\n",
    "encoder_sizes = [d_in, 32, 64, 128, 128, 128, 64, 48, representation_size]\n",
    "decoder_sizes = [representation_size + d_x, 32, 64, 128, 128, 128, 64, 48, d_out]\n",
    "\n",
    "model = cnp.DeterministicModel(encoder_sizes,decoder_sizes)\n",
    "model.load_state_dict(torch.load(f'./out/{f_out}_model.pth'))\n",
    "model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "mode=\"LF\"\n",
    "filelist = utils.get_all_files(f\"../simulation/out/{mode}/{version_lf}/tier2/neutron\")\n",
    "num_total_points = 50000\n",
    "\n",
    "MAX_CONTEXT_POINTS_NEW = int(1/3 * num_total_points)\n",
    "MAX_TARGET_POINTS_NEW = 2 * (MAX_CONTEXT_POINTS_NEW)\n",
    "bce = nn.BCELoss()\n",
    "\n",
    "x_lf = np.empty([0,6])\n",
    "sum_target_y_lf = np.empty([0,1])\n",
    "mean_mu_cnp_lf = np.empty([0,1])\n",
    "mean_sigma_cnp_lf = np.empty([0,1])\n",
    "rGe77_lf = np.empty([0,1])\n",
    "totGe77_lf = np.empty([0,1])\n",
    "totGe77_hf = np.empty([0,1])\n",
    "hist_target_sig_lf = hist_target_bkg_lf = hist_pred_sig_lf = hist_pred_bkg_lf = np.zeros(100)\n",
    "fout = open(f'{path_out}{f_out}_training.txt', \"a\")\n",
    "\n",
    "# create a PdfPages object\n",
    "pdf = PdfPages(f'{path_out}{f_out}_result_{mode}.pdf')\n",
    "\n",
    "for i,file in enumerate(filelist):\n",
    "    \n",
    "    path_to_files = file[:-4]\n",
    "    dataset_config = data.DataGeneration(num_iterations=1, num_context_points=MAX_CONTEXT_POINTS_NEW, num_target_points=MAX_TARGET_POINTS_NEW, batch_size = 1, use_data_augmentation=\"None\", path_to_files=path_to_files,x_size=x_size,y_size=y_size, mode = \"config\", ratio_testing=0.,names_x=names_x, name_y=name_y)\n",
    "    data_config = dataset_config.get_data(0, CONTEXT_IS_SUBSET)\n",
    "    # Get the predicted mean and variance at the target points for the testing set\n",
    "    log_prob_config, mu_config, sigma_config = model(data_config.query, data_config.target_y)\n",
    "    # Define the loss\n",
    "    config_loss = -log_prob_config.mean()\n",
    "    if max(mu_config[0].detach().numpy()) <= 1 and min(mu_config[0].detach().numpy()) >= 0:\n",
    "            loss_bce_config = bce(mu_config,  data_config.target_y)\n",
    "    else:\n",
    "            loss_bce_config = -1.\n",
    "\n",
    "    mu_config = mu_config[0].detach().numpy()\n",
    "    target_y = data_config.target_y[0].detach().numpy()\n",
    "    df = pd.read_csv(file, index_col=0)\n",
    "    tmp = df[[\"fidelity\",\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\"]].to_numpy()\n",
    "    x_lf         = np.append(x_lf,[df[[\"fidelity\",\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\"]].to_numpy()[0]],axis=0)\n",
    "\n",
    "    sum_target_y_tmp = np.array([np.sum(target_y)])\n",
    "    sum_target_y_lf    = np.append(sum_target_y_lf, [sum_target_y_tmp], axis=0)\n",
    "    mean_mu_tmp = np.array([np.mean(mu_config)])\n",
    "    mean_mu_cnp_lf = np.append(mean_mu_cnp_lf, [mean_mu_tmp], axis=0)\n",
    "    mean_sigma_tmp = np.array([np.mean(sigma_config[0].detach().numpy())])\n",
    "    mean_sigma_cnp_lf = np.append(mean_sigma_cnp_lf, [mean_sigma_tmp], axis=0)\n",
    "    rGe77_lf = np.append(rGe77_lf,[np.array([np.sum(pd.read_csv(file)[\"prod_rate_Ge77[nuc/(kg*yr)]\"].to_numpy())])], axis=0)\n",
    "    totGe77_lf = np.append(totGe77_lf,[np.array([np.sum(pd.read_csv(file)[\"total_nC_Ge77[cts]\"].to_numpy())])], axis=0)\n",
    "\n",
    "    hist_target_sig2, hist_target_bkg2, hist_pred_sig2, hist_pred_bkg2 = plotting.sum_hist(mu_config, target_y, hist_target_sig_lf, hist_target_bkg_lf, hist_pred_sig_lf, hist_pred_bkg_lf)\n",
    "\n",
    "    print(\"{}/{} {}, {}, radius: {} cm, test loss: {} (bce {})\".format(i,len(filelist),datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), mode,x_lf[-1,1], config_loss, loss_bce_config))\n",
    "    fout.write(\"{}, Iteration: {}, test loss: {} (bce {})\\n\".format(datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), i, config_loss, loss_bce_config))\n",
    "    fig = plotting.plot_result_configwise(mu_config, target_y, f'{config_loss:.2f}', x_lf[-1][1:])\n",
    "    pdf.savefig(fig)\n",
    "\n",
    "    if i==15:\n",
    "        hist_target_sig2= hist_target_sig_lf\n",
    "        hist_target_bkg2 = hist_target_bkg_lf\n",
    "        hist_pred_sig2 = hist_pred_sig_lf\n",
    "        hist_pred_bkg2 = hist_pred_bkg_lf\n",
    "        plt.show()\n",
    "    plt.clf()\n",
    "    \n",
    "fig1 = plotting.plot_result_summed(hist_target_sig_lf, hist_target_bkg_lf, hist_pred_sig_lf, hist_pred_bkg_lf)\n",
    "pdf.savefig(fig1)\n",
    "#plt.show()\n",
    "plt.clf()\n",
    "pdf.close()\n",
    "\n",
    "fout.close()\n",
    "\n",
    "df = pd.DataFrame(x_lf, columns=[\"Mode\",\"Radius[cm]\",\"Thickness[cm]\",\"NPanels\",\"Theta[deg]\",\"Length[cm]\"])\n",
    "df['Ge-77[nevents]'] = sum_target_y_lf\n",
    "df['Ge-77_CNP'] = mean_mu_cnp_lf\n",
    "df['Ge-77_CNP_err'] = mean_sigma_cnp_lf\n",
    "df['rGe77[nuc/(kg*yr)]'] = rGe77_lf\n",
    "df=df.round(decimals=4)\n",
    "df.to_csv(f'{path_out}{f_out}_Ge77rates.csv')\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "mode=\"HF\"\n",
    "\n",
    "filelist = utils.get_all_files(f\"../simulation/out/{mode}/{version_hf}/tier2/neutron\")\n",
    "\n",
    "x_hf = np.empty([0,6])\n",
    "sum_target_y_hf = np.empty([0,1])\n",
    "mean_mu_cnp_hf = np.empty([0,1])\n",
    "mean_sigma_cnp_hf = np.empty([0,1])\n",
    "upper_lim_hf = np.empty([0,1])\n",
    "rGe77_hf = np.empty([0,1])\n",
    "totGe77_hf = np.empty([0,1])\n",
    "hist_target_sig_hf = hist_target_bkg_hf = hist_pred_sig_hf = hist_pred_bkg_hf = np.zeros(100)\n",
    "fout = open(f'{path_out}{f_out}_training.txt', \"a\")\n",
    "\n",
    "# create a PdfPages object\n",
    "pdf=PdfPages(f'{path_out}{f_out}_result_{mode}.pdf')\n",
    "\n",
    "for i,file in enumerate(filelist):\n",
    "\n",
    "    path_to_files = file[:-4]\n",
    "    num_total_points = 0\n",
    "    with open(file, \"rbU\") as f:\n",
    "        num_total_points += int(np.floor(sum(1 for _ in f)))\n",
    "\n",
    "    MAX_CONTEXT_POINTS_NEW = int(1/3 * (num_total_points-1))\n",
    "    MAX_TARGET_POINTS_NEW = 2 * MAX_CONTEXT_POINTS_NEW\n",
    "\n",
    "    dataset_config = data.DataGeneration(num_iterations=1, num_context_points=MAX_CONTEXT_POINTS_NEW, num_target_points=MAX_TARGET_POINTS_NEW, batch_size = 1, use_data_augmentation=\"None\", path_to_files=path_to_files,x_size=x_size,y_size=y_size, mode = \"config\", ratio_testing=0.,names_x=names_x, name_y=name_y)\n",
    "    data_config = dataset_config.get_data(0, CONTEXT_IS_SUBSET)\n",
    "    \n",
    "    # Get the predicted mean and variance at the target points for the testing set\n",
    "    log_prob_config, mu_config, sigma_config = model(data_config.query, data_config.target_y)\n",
    "    # Define the loss\n",
    "    config_loss = -log_prob_config.mean()\n",
    "    if max(mu_config[0].detach().numpy()) <= 1 and min(mu_config[0].detach().numpy()) >= 0:\n",
    "            loss_bce_config = bce(mu_config,  data_config.target_y)\n",
    "    else:\n",
    "            loss_bce_config = -1.\n",
    "\n",
    "    mu_config = mu_config[0].detach().numpy()\n",
    "    \n",
    "    target_y = data_config.target_y[0].detach().numpy()\n",
    "    df = pd.read_csv(file, index_col=0)\n",
    "    x_hf         = np.append(x_hf,[df[[\"fidelity\",\"radius\",\"thickness\",\"npanels\",\"theta\",\"length\"]].to_numpy()[0]],axis=0)\n",
    "    #x         = np.append(x,[data_config.query[1][0][0][-5:].numpy()],axis=0)\n",
    "    sum_target_y_tmp = np.array([np.sum(target_y)])\n",
    "    sum_target_y_hf    = np.append(sum_target_y_hf, [sum_target_y_tmp], axis=0)\n",
    "    mean_mu_tmp = np.array([np.mean(mu_config)])\n",
    "    upper_lim_hf = np.append(upper_lim_hf,[np.array([np.percentile(mu_config,95.)])], axis=0)\n",
    "    mean_mu_cnp_hf = np.append(mean_mu_cnp_hf, [mean_mu_tmp], axis=0)\n",
    "    mean_sigma_tmp = np.array([np.mean(sigma_config[0].detach().numpy())])\n",
    "    mean_sigma_cnp_hf = np.append(mean_sigma_cnp_hf, [mean_sigma_tmp], axis=0)\n",
    "    rGe77_hf = np.append(rGe77_hf,[np.array([np.sum(pd.read_csv(file)[\"prod_rate_Ge77[nuc/(kg*yr)]\"].to_numpy())])], axis=0)\n",
    "    totGe77_hf = np.append(totGe77_hf,[np.array([np.sum(pd.read_csv(file)[\"total_nC_Ge77[cts]\"].to_numpy())])], axis=0)\n",
    "\n",
    "    hist_target_sig_hf, hist_target_bkg_hf, hist_pred_sig_hf, hist_pred_bkg_hf = plotting.sum_hist(mu_config, target_y, hist_target_sig_hf, hist_target_bkg_hf, hist_pred_sig_hf, hist_pred_bkg_hf)\n",
    "    print(\"{}/{} {}, {}, radius: {} cm, test loss: {} (bce {})\".format(i,len(filelist),datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), mode,x_hf[-1,1], config_loss, loss_bce_config))\n",
    "    fout.write(\"{}, Iteration: {}, test loss: {} (bce {})\\n\".format(datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"), i, config_loss, loss_bce_config))\n",
    "   \n",
    "    fig = plotting.plot_result_configwise(mu_config, target_y, f'{config_loss:.2f}', x_hf[-1][1:])\n",
    "    pdf.savefig(fig)\n",
    "    #plt.show()\n",
    "    plt.clf()\n",
    "\n",
    "    \n",
    "fig1 = plotting.plot_result_summed(hist_target_sig_hf, hist_target_bkg_hf, hist_pred_sig_hf, hist_pred_bkg_hf)\n",
    "pdf.savefig(fig1)\n",
    "plt.show()\n",
    "plt.clf()\n",
    "pdf.close()\n",
    "\n",
    "fout.close()\n",
    "\n",
    "df= pd.read_csv(f'{path_out}{f_out}_Ge77rates.csv', index_col=0)\n",
    "x = df[[\"Mode\",\"Radius[cm]\",\"Thickness[cm]\",\"NPanels\",\"Theta[deg]\",\"Length[cm]\",\"Ge-77[nevents]\",\"Ge-77_CNP\",\"Ge-77_CNP_err\",\"rGe77[nuc/(kg*yr)]\"]].to_numpy()\n",
    "x_tmp = np.append(x_hf, sum_target_y_hf, axis=1)\n",
    "x_tmp = np.append(x_tmp, mean_mu_cnp_hf, axis=1)\n",
    "x_tmp = np.append(x_tmp, mean_sigma_cnp_hf, axis=1)\n",
    "x_tmp = np.append(x_tmp, rGe77_hf, axis=1)\n",
    "x = np.append(x, x_tmp, axis=0)\n",
    "df = pd.DataFrame(x, columns=df.columns)\n",
    "\n",
    "df=df.round(decimals=4)\n",
    "df.to_csv(f'{path_out}{f_out}_Ge77rates.csv')\n",
    "df.to_csv(f'../multi-fidelity-gaussian-process/in/Ge77_rates_CNP_{version_cnp}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_sim_hf=\"black\"\n",
    "color_cnp_hf=\"black\"\n",
    "color_sim_lf=\"red\"\n",
    "color_cnp_lf=\"teal\"\n",
    "\n",
    "figx = pkl.load(  open(f'./out/{f_out}_distr.p',  'rb')  )\n",
    "\n",
    "y1=[]\n",
    "x2=[]\n",
    "x1=[]\n",
    "y2=[]\n",
    "w1=[]\n",
    "w2=[]\n",
    "p = figx.axes[0].patches  # There are 10 patches\n",
    "for i in range(int(len(p)/2)):\n",
    "    x1.append(p[i].get_xy()[0])\n",
    "    y1.append(p[i].get_height())\n",
    "    w1.append(p[i].get_width())\n",
    "    x2.append(p[i+int(len(p)/2)].get_xy()[0])\n",
    "    y2.append(p[i+int(len(p)/2)].get_height())\n",
    "    w2.append(p[i+int(len(p)/2)].get_width())\n",
    "figx=[]\n",
    "p=[]\n",
    "\n",
    "xmin=[0,0,0,0,0]\n",
    "xmax=[265,20,360,90,150]\n",
    "indices = [0,2,1,3,4]\n",
    "indices_l = [\"(c)\",\"(d)\",\"(e)\",\"(f)\",\"(g)\"]\n",
    "xlabels=[\"Radius [cm]\",\"Thickness [cm]\",\"N Panels\",r\"Angle $\\varphi$ [deg]\",\"Length [cm]\"]\n",
    "\n",
    "fig = plt.figure(figsize=(18,6), layout = \"constrained\")\n",
    "gs0=fig.add_gridspec(1,2, width_ratios = [2,1])\n",
    "\n",
    "gs00 = gs0[0].subgridspec(2,2)\n",
    "gs01 = gs0[1].subgridspec(3,1)\n",
    "for l in range(2):\n",
    "    for j in range(2):\n",
    "        ax = fig.add_subplot(gs00[l,j])\n",
    "        if l== 0 and j==0:\n",
    "            ax.bar(x=x1,height=y1,width=w1, color=(113/255,150/255,159/255), alpha=0.8)\n",
    "            ax.bar(x=x2,height=y2,width=w2,color='coral', alpha=0.8)\n",
    "            ax.set_xlabel(r\"$y_{CNP}$\", fontsize=10)\n",
    "            ax.set_ylabel(r'Count', fontsize=10)\n",
    "            ax.set_yscale(\"log\")\n",
    "            #ax.set_ylim(0.1,1000000)\n",
    "            ax.text(.01, .99, '(a)', ha='left', va='top', transform=ax.transAxes)\n",
    "        if l == 0 and j==1:\n",
    "            nbins = len(hist_target_sig2)\n",
    "            range2 = [0.0, 1.0]\n",
    "            bin_length = (range2[1]-range2[0])/nbins\n",
    "            bins = np.arange(range2[0],range2[1]+bin_length, bin_length)\n",
    "            centroids = (bins[1:] + bins[:-1]) / 2\n",
    "\n",
    "            #ax.hist(centroids, weights = hist_target_bkg2, range=range2, bins=nbins, color=(3/255,37/255,46/255), alpha=0.8, label='Background (Label)')\n",
    "            ax.hist(centroids, weights = hist_pred_bkg2, range=range2, bins=nbins, color=(113/255,150/255,159/255), alpha=0.8, label='Background (CNP)')\n",
    "            #ax.hist(centroids, weights = hist_target_sig2, range=range2, bins=nbins, color='orangered', alpha=1.0, label='Signal (Label)')\n",
    "            ax.hist(centroids, weights = hist_pred_sig2, range=range2, bins=nbins, color='coral', alpha=0.8, label='Signal (CNP)')\n",
    "            ax.set_xlabel(r\"$y_{CNP}$\", fontsize=10)\n",
    "            ax.set_ylabel(r'Count', fontsize=10)\n",
    "            ax.set_yscale(\"log\")\n",
    "            #ax.set_ylim(0.1,1000000)\n",
    "            ax.text(.01, .99, '(b)', ha='left', va='top', transform=ax.transAxes)\n",
    "            ax.legend(loc=9, bbox_to_anchor=(0.795,1.), ncol=1,fontsize=10)\n",
    "\n",
    "        if l == 1:\n",
    "            i=j\n",
    "            plt.plot(x_lf[:,indices[i]+1],rGe77_lf,\"o\",markersize=2, color=color_sim_lf, alpha=0.3, label=\"LF (raw)\")\n",
    "            #plt.plot(x_hf[:,indices[i]+1],totGe77_hf,\">\",markersize=4, color=color_sim_hf, label=\"HF (raw)\")\n",
    "            ax.set_xlim(xmin[indices[i]],xmax[indices[i]])#\n",
    "            ax.set_xlabel(xlabels[indices[i]], fontsize=10)\n",
    "            ax.set_ylabel(r'$y_{raw}$',color=color_sim_lf, fontsize=10)\n",
    "            plt.tick_params(axis='y', labelcolor=color_sim_lf)\n",
    "            handles = plt.gca().get_legend_handles_labels()[0]\n",
    "            labels = plt.gca().get_legend_handles_labels()[1]\n",
    "\n",
    "            ax = plt.twinx()\n",
    "            plt.errorbar(x_lf[:,indices[i]+1], mean_mu_cnp_lf[:,0], yerr=mean_sigma_cnp_lf[:,0],fmt='o',markersize=2, elinewidth=0.5, color=color_cnp_lf, label=\"LF (CNP)\")\n",
    "            #plt.errorbar(x_hf[:,indices[i]+1], mean_mu_cnp_hf[:,0], yerr=mean_sigma_cnp_hf[:,0],fmt='s',markersize=4, color=color_cnp_hf, label=\"HF (CNP)\")\n",
    "            plt.tick_params(axis='y', labelcolor=color_cnp_lf)\n",
    "\n",
    "            for t in range(len(plt.gca().get_legend_handles_labels()[1])):\n",
    "                handles.append(plt.gca().get_legend_handles_labels()[0][t])\n",
    "                labels.append(plt.gca().get_legend_handles_labels()[1][t])\n",
    "\n",
    "            ax.set_ylabel(r'$y_{CNP}$',color=color_cnp_lf, fontsize=10)\n",
    "            ax.text(.01, .99,indices_l[i], ha='left', va='top', transform=ax.transAxes)\n",
    "            if j==0:\n",
    "                plt.legend(handles,labels,loc=9, bbox_to_anchor=(0.14,0.25),ncol=1)\n",
    "\n",
    "for j in range(3):\n",
    "    ax = fig.add_subplot(gs01[j])\n",
    "    i=j+2\n",
    "    plt.plot(x_lf[:,indices[i]+1],rGe77_lf,'o',markersize=2, color=color_sim_lf,alpha=0.3, label=\"LF (raw)\")\n",
    "    #plt.plot(x_hf[:,indices[i]+1],totGe77_hf,'s',markersize=4, color=color_sim_hf, label=\"HF (raw)\")\n",
    "    ax.set_xlim(xmin[indices[i]],xmax[indices[i]])#\n",
    "    ax.set_xlabel(xlabels[indices[i]], fontsize=10)\n",
    "    ax.set_ylabel(r'$y_{raw}$',color=color_sim_lf, fontsize=10)\n",
    "    \n",
    "    plt.tick_params(axis='y', labelcolor=color_sim_lf)\n",
    "    \n",
    "    ax = plt.twinx()\n",
    "    #plt.plot(x_lf[:,indices[i]+1], mean_mu_cnp_lf[:,0],'o',markersize=2, color=color_cnp_lf, label=\"LF (CNP)\")\n",
    "    plt.errorbar(x_lf[:,indices[i]+1], mean_mu_cnp_lf[:,0], yerr=mean_sigma_cnp_lf[:,0],fmt='o',markersize=2, elinewidth=0.5, color=color_cnp_lf, label=\"LF (CNP)\")\n",
    "    #plt.errorbar(x_hf[:,indices[i]+1], mean_mu_cnp_hf[:,0], yerr=mean_sigma_cnp_hf[:,0],fmt='>',markersize=4, color=color_cnp_hf, label=\"HF (CNP)\")\n",
    "    plt.tick_params(axis='y', labelcolor=color_cnp_lf)\n",
    "    ax.set_ylabel(r'$y_{CNP}$',color=color_cnp_lf, fontsize=10)\n",
    "    ax.text(.01, .99,indices_l[i], ha='left', va='top', transform=ax.transAxes)\n",
    "fig.savefig(f'{path_out}{f_out}_result.png')\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
