{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "56ef9d85-1fb3-4696-a632-7d72c8f4fd4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import packages\n",
    "import scipy\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from torch.optim import Adam, lr_scheduler\n",
    "\n",
    "from IPython.display import display, Math, Latex\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "import matplotlib as mpl\n",
    "from tqdm import tqdm, trange\n",
    "import pandas as pd\n",
    "import ot\n",
    "import os\n",
    "\n",
    "import functions as func\n",
    "import diffusion as diff\n",
    "import synthetic_data as synt\n",
    "\n",
    "import seaborn as sns\n",
    "sns.set_theme()\n",
    "import pickle\n",
    "import time\n",
    "\n",
    "import sampler as sp\n",
    "import decoder\n",
    "from functions import gaussian, empirical, kl_divergence, kl, wasserstein_w2, w2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c0b871a2-1565-4fb4-81da-f067a3cdb73a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically re-import edited local modules (functions, diffusion, sampler, decoder, ...) before each cell runs\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6eb818e9-fb8d-4fc4-9d83-877bf80c9d84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Utilisation du CPU\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x104c92930>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Devices and random seeds\n",
    "\n",
    "# CPU is forced by default; set FORCE_CPU = False to use a GPU when one is available.\n",
    "FORCE_CPU = True\n",
    "\n",
    "device_cuda = torch.device(\"cuda:0\")\n",
    "device_cpu = torch.device(\"cpu\")\n",
    "\n",
    "use_gpu = (not FORCE_CPU) and torch.cuda.is_available() and torch.backends.cuda.is_built()\n",
    "device = device_cuda if use_gpu else device_cpu\n",
    "# Report the device that is actually used.\n",
    "print(\"Utilisation du GPU\" if use_gpu else \"Utilisation du CPU\")\n",
    "\n",
    "# Fixed seed for reproducibility (torch.random.seed() draws a fresh one instead).\n",
    "seed = 20\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73de160d-7dbf-4739-8aa7-16e29e0a3299",
   "metadata": {},
   "source": [
    "## Isotropic Gaussian data: explicit-score experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5c8d596-4bb5-4b00-8c08-9a7ed4cc2a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### TRAINING DATA (isotropic case: SIGMA = scale_var * I_d) ###########\n",
    "n = 10000   # number of training points\n",
    "d = 50      # dimension\n",
    "mu = torch.ones(d, device=device)\n",
    "\n",
    "scale_var = .5\n",
    "SIGMA = scale_var * torch.eye(d, device=device)\n",
    "\n",
    "training_distribution = func.gaussian(d, mu, SIGMA)\n",
    "training_sample = training_distribution.generate_sample(n)\n",
    "\n",
    "########### FORWARD DIFFUSION AND NOISE SCHEDULE ###########\n",
    "sigma_inv = 1\n",
    "beta_min, beta_max = 0.1, 20\n",
    "a = 0.   # schedule parameter used to build `beta` below\n",
    "T = 1\n",
    "\n",
    "# Values of the beta_parametric shape parameter `a` swept in the experiments\n",
    "a_values = np.linspace(-10., 10., 21)\n",
    "beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "# alternative schedule: beta = func.beta_cosine(T, 0.003)\n",
    "\n",
    "sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d69020-2277-4b90-b219-e156de9aa68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################# SET SAMPLING PARAMETERS #################################\n",
    "sample_batch_size = 10000  # size of the sample generated\n",
    "num_steps = 500            # number of discretisation steps of the samplers\n",
    "\n",
    "# Starting points for the backward samplers, drawn from `sde.final`.\n",
    "init = sde.final.generate_sample(sample_batch_size)\n",
    "\n",
    "# Example: sampling with the exact Gaussian score\n",
    "#xbarT_euler = sp.Euler_Maruyama_discr_sampler(init, sde, diff.explicit_score(sde, training_distribution), num_steps)\n",
    "#xbarT_semii = sp.EI_discr_sampler(init, sde, diff.explicit_score(sde, training_distribution), num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36b8b0c3-29a2-47ac-a428-18e3428d0d98",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### EXACT SCORE SIMULATION GAUSSIAN CASE ####\n",
    "num_steps_vec = np.array([num_steps//4, num_steps//2, num_steps])  # to adjust as desired\n",
    "\n",
    "# These runs are expensive, so they are only recomputed when the cached table is missing.\n",
    "distances_path = 'models/gaussian_iso/d50_explicit_pkl/distances_df.pkl'\n",
    "if not os.path.exists(distances_path):\n",
    "    init_exact = sde.final.generate_sample(sample_batch_size)\n",
    "    distances = []\n",
    "    for ns in tqdm(num_steps_vec):\n",
    "        for a_val in a_values:\n",
    "            sde.beta.change_a(a_val)\n",
    "            score_exact = diff.explicit_score(sde, training_distribution)\n",
    "            for scheme_name, scheme in [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]:\n",
    "                sample_dist = empirical(scheme(init_exact, sde, score_exact, ns))\n",
    "                distances.append({\n",
    "                    \"a\": a_val, \"num_steps\": ns,\n",
    "                    \"scheme\": scheme_name,\n",
    "                    \"loss\": \"exact_score\",\n",
    "                    \"kl\": kl(training_distribution, sample_dist),\n",
    "                    \"w2\": w2(training_distribution, sample_dist),\n",
    "                })\n",
    "    sde.beta.change_a(a)  # restore the schedule parameter the SDE was built with\n",
    "    distances_df = pd.DataFrame(distances)\n",
    "    os.makedirs(os.path.dirname(distances_path), exist_ok=True)\n",
    "    distances_df.to_pickle(distances_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7637b768-1bbf-4093-8ac1-31b4bcd853bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached exact-score distances (columns used below: a, num_steps, scheme, loss, kl, w2)\n",
    "file_path_load = 'models/gaussian_iso/d50_explicit_pkl/distances_df.pkl'\n",
    "distances_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c431d31e-91ee-4b33-b7a0-3c5c5a3dc8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(figsize=(8, 4), ncols=2, sharey=True, sharex=True)\n",
    "axs[0].set_title(\"KL divergence, exact score, Euler scheme\")\n",
    "axs[1].set_title(\"KL divergence, exact score, Semi-Implicit scheme\")\n",
    "\n",
    "for ax, scheme in zip(axs, [\"euler\", \"semii\"]):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == ns)\n",
    "        res_kl = distances_df[select].groupby(\"a\")[\"kl\"].mean()\n",
    "        # x comes from the group keys, so it always matches res_kl in length and order\n",
    "        ax.plot(res_kl.index, res_kl, label=f\"num_steps = {ns}\")\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"KL divergence\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3504bad-34a8-4894-aa37-15fe9dbdbe0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(figsize=(8, 4), ncols=2, sharey=True, sharex=True)\n",
    "axs[0].set_title(\"W2 distance, exact score, Euler scheme\")\n",
    "axs[1].set_title(\"W2 distance, exact score, Semi-Implicit scheme\")\n",
    "\n",
    "for ax, scheme in zip(axs, [\"euler\", \"semii\"]):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == ns)\n",
    "        res_w2 = distances_df[select].groupby(\"a\")[\"w2\"].mean()\n",
    "        # x comes from the group keys, so it always matches res_w2 in length and order\n",
    "        ax.plot(res_w2.index, res_w2, label=f\"num_steps = {ns}\")\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"W2 distance\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce775e44-cf7d-4573-9c76-699a8e1c6104",
   "metadata": {},
   "outputs": [],
   "source": [
    "######## SELECT TRAINING PARAMETERS W.R.T THE DIMENSION ######################\n",
    "# optimisation parameters per dimension: d -> (n_epochs, learning_rate)\n",
    "training_params = {\n",
    "    5: (2, 1.0e-4),\n",
    "    10: (30, 1.0e-4),\n",
    "    25: (50, 1.0e-3),\n",
    "    50: (150, 1.0e-3),\n",
    "}\n",
    "if d not in training_params:\n",
    "    # fail early instead of a NameError on n_epochs / learning_rate further down\n",
    "    raise ValueError(f\"no training parameters for d = {d}; supported: {sorted(training_params)}\")\n",
    "n_epochs, learning_rate = training_params[d]\n",
    "\n",
    "# network_parameters\n",
    "mid_features = 256\n",
    "num_layers = 3\n",
    "batch_size = 64\n",
    "\n",
    "# Monte Carlo estimate samples (for E2 computation in the KL bound)\n",
    "num_mc = 500\n",
    "\n",
    "# number of runs (for replication)\n",
    "rep_size = 10\n",
    "\n",
    "# load data\n",
    "dataloader = DataLoader(training_sample, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# number of discretisation steps (alias of num_steps so the two cannot drift apart)\n",
    "num_step = num_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f92c37b9-f7fa-480c-96e1-a9227486d519",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "# Set TRAIN_EXPLICIT = True to retrain every (a, replicate) network (slow) and overwrite the\n",
    "# checkpoints; by default the pretrained state dicts are loaded from disk.\n",
    "TRAIN_EXPLICIT = False\n",
    "explicit_model_dir = f'models/gaussian_iso/d{d}_explicit'\n",
    "\n",
    "score_theta_trained = [[] for _ in a_values]   # [index of a][replicate]\n",
    "for k, a in enumerate(a_values):\n",
    "    if TRAIN_EXPLICIT:\n",
    "        print(f\"optimization for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        model_path = f'{explicit_model_dir}/model_{k}_{j}_d{d}.pt'\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        if TRAIN_EXPLICIT:\n",
    "            optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "            # NOTE(review): these checkpoints live under *_explicit, yet the training code used\n",
    "            # diff.loss_conditional; diff.loss_explicit(network, sde, diff.explicit_score(sde, training_distribution))\n",
    "            # may have been intended here - confirm before retraining.\n",
    "            loss = diff.loss_conditional(network, sde)\n",
    "            diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "            os.makedirs(explicit_model_dir, exist_ok=True)\n",
    "            torch.save(network.state_dict(), model_path)\n",
    "        else:\n",
    "            # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine\n",
    "            network.load_state_dict(torch.load(model_path, map_location=device))\n",
    "        score_theta_trained[k].append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f374786f-7b75-4050-9810-b01cff494e5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# Set RUN_EXPLICIT_SIMULATION = True to recompute, for every (a, replicate), the E2 error and the\n",
    "# KL / W2 distances of both samplers, and cache them; by default the cached table is loaded below.\n",
    "RUN_EXPLICIT_SIMULATION = False\n",
    "\n",
    "if RUN_EXPLICIT_SIMULATION:\n",
    "    sample_init = sde.final.generate_sample(sample_batch_size)\n",
    "    simulation_results = []\n",
    "    for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "        print(f\"simulation and M for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "        exp_score = diff.explicit_score(sde, training_distribution)\n",
    "\n",
    "        for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "            error_approx_E2, error_approx_sup_L2 = func.compute_E2(training_distribution, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "            for scheme_name, scheme_sampler in [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]:\n",
    "                sample_dist = empirical(scheme_sampler(sample_init, sde, score_theta, num_steps))\n",
    "                simulation_results.append({\n",
    "                    \"a\": a,\n",
    "                    \"replication\": j,\n",
    "                    \"scheme\": scheme_name,\n",
    "                    \"error_approx_E2\": error_approx_E2,\n",
    "                    \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                    \"kl\": kl(training_distribution, sample_dist),\n",
    "                    \"w2\": w2(training_distribution, sample_dist),\n",
    "                })\n",
    "\n",
    "    simulation_df = pd.DataFrame(simulation_results)\n",
    "    os.makedirs('models/gaussian_iso/d50_explicit_pkl', exist_ok=True)\n",
    "    simulation_df.to_pickle('models/gaussian_iso/d50_explicit_pkl/simulation_df.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf77043-28e3-4bb9-b75d-55361d5a1432",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cached simulation results of the explicit models\n",
    "# (columns: a, replication, scheme, error_approx_E2, error_approx_sup_L2, kl, w2)\n",
    "file_path_load = 'models/gaussian_iso/d50_explicit_pkl/simulation_df.pkl'\n",
    "simulation_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e94310b-7e70-40a1-a1ff-284cc87ef896",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set RUN_EXPLICIT_BOUNDS = True to recompute, for every a, the E1 and E3 terms of the KL bound\n",
    "# and the W2 bound, and cache them; by default the cached table is loaded in the next cell.\n",
    "RUN_EXPLICIT_BOUNDS = False\n",
    "\n",
    "if RUN_EXPLICIT_BOUNDS:\n",
    "    # epsilon: replicate-averaged error_approx_sup_L2 for each a\n",
    "    epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "    error_results = []\n",
    "    for a in tqdm(a_values):\n",
    "        print(f\"bound for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "        error_results.append({\n",
    "            \"a\": a,\n",
    "            \"error_mixing_E1\": func.compute_E1(training_distribution, sde),\n",
    "            \"error_discr_E3\": func.compute_E3(training_distribution, sde, num_steps),\n",
    "            \"w2_error\": func.compute_w2_bound(training_distribution, training_sample, sde, num_steps, epsilon[a]),\n",
    "        })\n",
    "    errors_df = pd.DataFrame(error_results)\n",
    "    os.makedirs('models/gaussian_iso/d50_explicit_pkl', exist_ok=True)\n",
    "    errors_df.to_pickle('models/gaussian_iso/d50_explicit_pkl/errors_df.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0ed020-0abd-40e4-af60-7bdf96f215d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cached bound terms of the explicit models (columns: a, error_mixing_E1, error_discr_E3, w2_error)\n",
    "file_path_load = 'models/gaussian_iso/d50_explicit_pkl/errors_df.pkl'\n",
    "errors_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6137c88a-2f50-498e-963e-626f9d673d68",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "by_a = results_df.groupby('a')\n",
    "\n",
    "# computation of the KL upper bound: E1 + E2 + E3 (E2 averaged over replicates)\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "\n",
    "# computation of the W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# generation results, aggregated over replicates for each scheme\n",
    "euler_by_a = results_df[results_df['scheme'] == 'euler'].groupby('a')\n",
    "ei_by_a = results_df[results_df['scheme'] == 'semii'].groupby('a')\n",
    "\n",
    "mean_all_Euler_KL = euler_by_a['kl'].mean()\n",
    "std_all_Euler_KL = euler_by_a['kl'].std()\n",
    "mean_all_EI_KL = ei_by_a['kl'].mean()\n",
    "std_all_EI_KL = ei_by_a['kl'].std()\n",
    "\n",
    "mean_all_Euler_W2 = euler_by_a['w2'].mean()\n",
    "std_all_Euler_W2 = euler_by_a['w2'].std()\n",
    "mean_all_EI_W2 = ei_by_a['w2'].mean()\n",
    "std_all_EI_W2 = ei_by_a['w2'].std()\n",
    "\n",
    "# the a = 0 Euler runs serve as the VPSDE baseline\n",
    "VPSDE_mean_KL = mean_all_Euler_KL[0.0]\n",
    "VPSDE_std_KL = std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2[0.0]\n",
    "\n",
    "# Exact-score results at the sampling resolution num_steps, one value per a (sorted by a,\n",
    "# so they line up with a_values when plotted)\n",
    "exact_score_runs = distances_df[(distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)]\n",
    "exact_euler_by_a = exact_score_runs[exact_score_runs[\"scheme\"] == 'euler'].groupby('a')\n",
    "exact_ei_by_a = exact_score_runs[exact_score_runs[\"scheme\"] == 'semii'].groupby('a')\n",
    "true_score_KL_Euler = exact_euler_by_a['kl'].mean()\n",
    "true_score_KL_EI = exact_ei_by_a['kl'].mean()\n",
    "true_score_W2_Euler = exact_euler_by_a['w2'].mean()\n",
    "true_score_W2_EI = exact_ei_by_a['w2'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7df1040-4e36-45c3-bb80-439eef463b3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT KL VS BOUND\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# plt.get_cmap is available in every matplotlib version (cm.get_cmap was removed in 3.9)\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "# Left axis: KL upper bound with a +/- 1 std band of error_approx_E2 across replicates\n",
    "y1_color = cmap(1/20)\n",
    "# float() accepts Python/NumPy scalars as well as 0-d tensors stored in the results table\n",
    "kl_ub = np.array([float(x) for x in KL_upperbound])\n",
    "kl_ub_std = results_df.groupby('a')['error_approx_E2'].std().to_numpy()\n",
    "ax1.plot(a_values, kl_ub, color=y1_color, label=\"Upper-bound\", linewidth=3)\n",
    "ax1.fill_between(a_values, kl_ub - kl_ub_std, kl_ub + kl_ub_std, fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (KL)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "# Right axis: KL between the data distribution and the generated samples (Euler scheme)\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "color_0 = 'gray' #mean_all_Euler_KL\n",
    "color_bis = 'black'\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_KL, color=y2_color, label=r\"KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN) \", linewidth=3)\n",
    "ax2.plot(a_values, true_score_KL_Euler, color=y2_color, label=r\"KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\", linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_KL-std_all_Euler_KL,\n",
    "                 mean_all_Euler_KL+std_all_Euler_KL,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_KL), color=y3_color, label=\"VPSDE (NN)\", linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_KL-VPSDE_std_KL,\n",
    "                 VPSDE_mean_KL+VPSDE_std_KL,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel(\"KL divergence\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "# plt.legend() would only list the twin axis (ax2); merge both axes so the upper bound appears too.\n",
    "handles1, labels1 = ax1.get_legend_handles_labels()\n",
    "handles2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles1 + handles2, labels1 + labels2)\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd361256-9bc5-4d19-9370-d9b3162e6252",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN ########\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# plt.get_cmap is available in every matplotlib version (cm.get_cmap was removed in 3.9)\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "# Left axis: W2 upper bound with a +/- 1 std band of error_approx_sup_L2 across replicates\n",
    "y1_color = cmap(1/20)\n",
    "w2_ub = np.array([float(x) for x in W2_upperbound])\n",
    "w2_ub_std = simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy()\n",
    "ax1.plot(a_values, w2_ub, color=y1_color, label=\"Upper-bound\", linewidth=3)\n",
    "ax1.fill_between(a_values, w2_ub - w2_ub_std, w2_ub + w2_ub_std, fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (W2)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "# Right axis: W2 distance between the data distribution and the generated samples (Euler scheme)\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "color_0 = 'gray' #mean_all_Euler_KL\n",
    "color_bis = 'black'\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label=r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \", linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label=r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (exact score)\", linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y3_color, label=\"VPSDE (NN)\", linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel(r\"$\\mathcal{W}_2$ distance\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# plt.legend() would only list the twin axis (ax2); merge both axes so the upper bound appears too.\n",
    "handles1, labels1 = ax1.get_legend_handles_labels()\n",
    "handles2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles1 + handles2, labels1 + labels2, loc='upper right')\n",
    "#plt.savefig(\"isotropic_num_steps_500_W_2_C_t+L_t_corr_M1.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bd2723c-a8a4-4ce7-8eb8-006018f4f900",
   "metadata": {},
   "source": [
    "## Isotropic Gaussian data: conditional-loss training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "944a213d-4e6f-4ac8-876a-ddf6e09a298c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "# Set TRAIN_COND = True to retrain every (a, replicate) network with the conditional loss (slow)\n",
    "# and overwrite the checkpoints; by default the pretrained state dicts are loaded from disk.\n",
    "TRAIN_COND = False\n",
    "cond_model_dir = f'models/gaussian_iso/d{d}_cond_ep{n_epochs}'\n",
    "\n",
    "score_theta_trained = [[] for _ in a_values]   # [index of a][replicate]\n",
    "for k, a in enumerate(a_values):\n",
    "    if TRAIN_COND:\n",
    "        print(f\"optimization for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        model_path = f'{cond_model_dir}/model_{k}_{j}_d{d}.pt'\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        if TRAIN_COND:\n",
    "            optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "            loss = diff.loss_conditional(network, sde)\n",
    "            diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "            os.makedirs(cond_model_dir, exist_ok=True)\n",
    "            torch.save(network.state_dict(), model_path)\n",
    "        else:\n",
    "            # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine\n",
    "            network.load_state_dict(torch.load(model_path, map_location=device))\n",
    "        score_theta_trained[k].append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6e1d847-65eb-45c3-83c2-52b58165392c",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# Starting points of the backward samplers, drawn from `sde.final` here so that this cell is self-contained.\n",
    "sample_init = sde.final.generate_sample(sample_batch_size)\n",
    "simulation_results = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    print(f\"simulation and M for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    exp_score = diff.explicit_score(sde, training_distribution)\n",
    "\n",
    "    for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "        error_approx_E2, error_approx_sup_L2 = func.compute_E2(training_distribution, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "        for scheme_name, scheme_sampler in [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]:\n",
    "            sample = scheme_sampler(sample_init, sde, score_theta, num_steps)\n",
    "            sample_dist = empirical(sample)\n",
    "            kl_value = kl(training_distribution, sample_dist)\n",
    "            w2_value = w2(training_distribution, sample_dist)\n",
    "\n",
    "            simulation_results.append({\n",
    "                \"a\": a,\n",
    "                \"replication\": j,\n",
    "                \"scheme\": scheme_name,\n",
    "                \"error_approx_E2\": error_approx_E2,\n",
    "                \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                \"kl\": kl_value,\n",
    "                \"w2\": w2_value,\n",
    "            })\n",
    "\n",
    "simulation_df = pd.DataFrame(simulation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b95598-1bac-481c-86e1-622645426dfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the conditional-model simulation results computed above (disabled by default).\n",
    "# They go next to the conditional models, never into d50_explicit_pkl, so the cached\n",
    "# explicit-model tables are not overwritten. All cached tables are reloaded further down.\n",
    "SAVE_COND_SIMULATION = False\n",
    "if SAVE_COND_SIMULATION:\n",
    "    cond_pkl_dir = f'models/gaussian_iso/d{d}_cond_ep{n_epochs}_pkl'\n",
    "    os.makedirs(cond_pkl_dir, exist_ok=True)\n",
    "    simulation_df.to_pickle(f'{cond_pkl_dir}/simulation_df.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891f9e9f-8722-4682-935c-cb230b77d53d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# epsilon: replicate-averaged error_approx_sup_L2 of the current simulation_df, for each a\n",
    "epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "error_results = []\n",
    "\n",
    "for a in tqdm(a_values):\n",
    "    print(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "\n",
    "    error_mixing_E1 = func.compute_E1(training_distribution, sde)\n",
    "    error_discr_E3 = func.compute_E3(training_distribution, sde, num_steps)\n",
    "    w2_error = func.compute_w2_bound(training_distribution, training_sample, sde, num_steps, epsilon[a])\n",
    "\n",
    "    error_results.append({\n",
    "        \"a\": a,\n",
    "        \"error_mixing_E1\": error_mixing_E1,\n",
    "        \"error_discr_E3\": error_discr_E3,\n",
    "        \"w2_error\": w2_error\n",
    "    })\n",
    "\n",
    "errors_df = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae937e80-1d4d-4b61-8ddc-2921b6a15c59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload every cached table so the comparison cells below do not depend on which of the\n",
    "# expensive cells above were actually run.\n",
    "#\n",
    "# To cache the conditional-model bounds computed in the previous cell, save them next to the\n",
    "# conditional results (not into d50_explicit_pkl, which holds the explicit-model tables):\n",
    "#errors_df.to_pickle(f'models/gaussian_iso/d{d}_cond_ep{n_epochs}_pkl/errors_df_cond_{n_epochs}.pkl')\n",
    "\n",
    "#### explicit\n",
    "explicit_pkl_dir = 'models/gaussian_iso/d50_explicit_pkl'\n",
    "distances_df = pd.read_pickle(f'{explicit_pkl_dir}/distances_df.pkl')\n",
    "simulation_df = pd.read_pickle(f'{explicit_pkl_dir}/simulation_df.pkl')\n",
    "errors_df = pd.read_pickle(f'{explicit_pkl_dir}/errors_df.pkl')\n",
    "\n",
    "#### conditional, 150 epochs\n",
    "cond_150_pkl_dir = f'models/gaussian_iso/d{d}_cond_ep150_pkl'\n",
    "simulation_df_cond_150 = pd.read_pickle(f'{cond_150_pkl_dir}/simulation_df.pkl')\n",
    "errors_df_cond_150 = pd.read_pickle(f'{cond_150_pkl_dir}/errors_df_cond_150.pkl')\n",
    "\n",
    "#### conditional, 300 epochs\n",
    "cond_300_pkl_dir = f'models/gaussian_iso/d{d}_cond_ep300_pkl'\n",
    "simulation_df_cond_300 = pd.read_pickle(f'{cond_300_pkl_dir}/simulation_df.pkl')\n",
    "errors_df_cond_300 = pd.read_pickle(f'{cond_300_pkl_dir}/errors_df_cond_300.pkl')\n",
    "\n",
    "# n_epochs is left at the last configuration loaded (later cells may read it).\n",
    "n_epochs = 300"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b2a43a-49a6-44d8-9840-c25bdbe246dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "by_a = results_df.groupby('a')\n",
    "\n",
    "# computation of the KL upper bound: E1 + E2 + E3 (E2 averaged over replicates)\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "\n",
    "# computation of the W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# generation results, aggregated over replicates for each scheme\n",
    "euler_by_a = results_df[results_df['scheme'] == 'euler'].groupby('a')\n",
    "ei_by_a = results_df[results_df['scheme'] == 'semii'].groupby('a')\n",
    "\n",
    "mean_all_Euler_KL = euler_by_a['kl'].mean()\n",
    "std_all_Euler_KL = euler_by_a['kl'].std()\n",
    "mean_all_EI_KL = ei_by_a['kl'].mean()\n",
    "std_all_EI_KL = ei_by_a['kl'].std()\n",
    "\n",
    "mean_all_Euler_W2 = euler_by_a['w2'].mean()\n",
    "std_all_Euler_W2 = euler_by_a['w2'].std()\n",
    "mean_all_EI_W2 = ei_by_a['w2'].mean()\n",
    "std_all_EI_W2 = ei_by_a['w2'].std()\n",
    "\n",
    "# the a = 0 Euler runs serve as the VPSDE baseline\n",
    "VPSDE_mean_KL = mean_all_Euler_KL[0.0]\n",
    "VPSDE_std_KL = std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2[0.0]\n",
    "\n",
    "# Exact-score results at the sampling resolution num_steps, one value per a (sorted by a,\n",
    "# so they line up with a_values when plotted)\n",
    "exact_score_runs = distances_df[(distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)]\n",
    "exact_euler_by_a = exact_score_runs[exact_score_runs[\"scheme\"] == 'euler'].groupby('a')\n",
    "exact_ei_by_a = exact_score_runs[exact_score_runs[\"scheme\"] == 'semii'].groupby('a')\n",
    "true_score_KL_Euler = exact_euler_by_a['kl'].mean()\n",
    "true_score_KL_EI = exact_ei_by_a['kl'].mean()\n",
    "true_score_W2_Euler = exact_euler_by_a['w2'].mean()\n",
    "true_score_W2_EI = exact_ei_by_a['w2'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7283d080-0af7-4933-981c-4f1788b8434c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bounds vs. generation error for the *_cond_150 results (conditionally-trained networks).\n",
    "results_df_cond_150 = pd.merge(simulation_df_cond_150, errors_df_cond_150, on=\"a\", how=\"left\")\n",
    "by_a = results_df_cond_150.groupby('a')\n",
    "\n",
    "# computation of the KL upper bound: E1 + E2 + E3\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()   # mean over replicates\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "# replicate spread of E2, drawn as the band around the KL bound; it must come from\n",
    "# this cell's frame, not from the explicit-score results_df\n",
    "KL_upperbound_std = by_a['error_approx_E2'].std().to_numpy()\n",
    "\n",
    "# computation of the W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# generation results KL / W2: filter each sampling scheme once and reuse the groupby\n",
    "euler_by_a = results_df_cond_150[results_df_cond_150['scheme'] == 'euler'].groupby('a')\n",
    "ei_by_a = results_df_cond_150[results_df_cond_150['scheme'] == 'semii'].groupby('a')\n",
    "mean_all_Euler_KL = euler_by_a['kl'].mean()\n",
    "std_all_Euler_KL = euler_by_a['kl'].std()\n",
    "mean_all_EI_KL = ei_by_a['kl'].mean()\n",
    "std_all_EI_KL = ei_by_a['kl'].std()\n",
    "mean_all_Euler_W2 = euler_by_a['w2'].mean()\n",
    "std_all_Euler_W2 = euler_by_a['w2'].std()\n",
    "mean_all_EI_W2 = ei_by_a['w2'].mean()\n",
    "std_all_EI_W2 = ei_by_a['w2'].std()\n",
    "\n",
    "# generation results VPSDE: the a = 0.0 entry of the Euler scheme is the VPSDE baseline\n",
    "VPSDE_mean_KL = mean_all_Euler_KL[0.0]\n",
    "VPSDE_std_KL = std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2[0.0]\n",
    "\n",
    "# exact score results KL / W2 (rows of distances_df with loss == exact_score at the current num_steps)\n",
    "exact_score_mask = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "true_score_KL_Euler = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'euler')]['kl']\n",
    "true_score_KL_EI = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'semii')]['kl']\n",
    "true_score_W2_Euler = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'euler')]['w2']\n",
    "true_score_W2_EI = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'semii')]['w2']\n",
    "\n",
    "# shared plot style for both figures\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "cmap = plt.get_cmap('tab20c')  # cm.get_cmap is deprecated (removed in Matplotlib 3.9)\n",
    "y1_color = cmap(1/20)\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "\n",
    "##### PLOT KL VS BOUND\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "KL_upperbound_values = np.array([x.item() for x in KL_upperbound])\n",
    "ax1.plot(a_values, KL_upperbound_values, color=y1_color, label='Upper-bound', linewidth=3)\n",
    "ax1.fill_between(a_values,\n",
    "                 KL_upperbound_values - KL_upperbound_std,\n",
    "                 KL_upperbound_values + KL_upperbound_std,\n",
    "                 fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel('Values of a')\n",
    "ax1.set_ylabel('Upper bound (KL)', color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "# raw strings: '\\pi', '\\rm', '\\hat' are LaTeX, not Python escape sequences\n",
    "ax2.plot(a_values, mean_all_Euler_KL, color=y2_color, label=r'KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN)', linewidth=3)\n",
    "ax2.plot(a_values, true_score_KL_Euler, color=y2_color, label=r'KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)', linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_KL - std_all_Euler_KL,\n",
    "                 mean_all_Euler_KL + std_all_Euler_KL,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "# float-typed constant line: np.full_like(a_values, ...) would truncate it if a_values were integers\n",
    "ax2.plot(a_values, np.full(len(a_values), VPSDE_mean_KL, dtype=float), color=y3_color, label='VPSDE (NN)', linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_KL - VPSDE_std_KL,\n",
    "                 VPSDE_mean_KL + VPSDE_std_KL,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel('KL divergence', color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "# one legend for both y-axes (plt.legend() alone only lists the twin axis' entries)\n",
    "handles_1, labels_1 = ax1.get_legend_handles_labels()\n",
    "handles_2, labels_2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles_1 + handles_2, labels_1 + labels_2)\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()\n",
    "\n",
    "##### PLOT WASSERSTEIN VS BOUND\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "ax1.plot(a_values, W2_upperbound, color=y1_color, label='Upper-bound', linewidth=3)\n",
    "ax1.set_xlabel('Values of a')\n",
    "ax1.set_ylabel('Upper bound (W2)', color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label=r'W2($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN)', linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label=r'W2($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)', linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2 - std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2 + std_all_Euler_W2,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "ax2.plot(a_values, np.full(len(a_values), VPSDE_mean_W2, dtype=float), color=y3_color, label='VPSDE (NN)', linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2 - VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2 + VPSDE_std_W2,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel('W2 distance', color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "#plt.legend()\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7994e9a0-8f1c-46e3-b4b7-b4c23b1484a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bounds vs. generation error for the *_cond_300 results (conditionally-trained networks).\n",
    "results_df_cond_300 = pd.merge(simulation_df_cond_300, errors_df_cond_300, on=\"a\", how=\"left\")\n",
    "by_a = results_df_cond_300.groupby('a')\n",
    "\n",
    "# computation of the KL upper bound: E1 + E2 + E3\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()   # mean over replicates\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "# replicate spread of E2, drawn as the band around the KL bound; it must come from\n",
    "# this cell's frame, not from the explicit-score results_df\n",
    "KL_upperbound_std = by_a['error_approx_E2'].std().to_numpy()\n",
    "\n",
    "# computation of the W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# generation results KL / W2: filter each sampling scheme once and reuse the groupby\n",
    "euler_by_a = results_df_cond_300[results_df_cond_300['scheme'] == 'euler'].groupby('a')\n",
    "ei_by_a = results_df_cond_300[results_df_cond_300['scheme'] == 'semii'].groupby('a')\n",
    "mean_all_Euler_KL = euler_by_a['kl'].mean()\n",
    "std_all_Euler_KL = euler_by_a['kl'].std()\n",
    "mean_all_EI_KL = ei_by_a['kl'].mean()\n",
    "std_all_EI_KL = ei_by_a['kl'].std()\n",
    "mean_all_Euler_W2 = euler_by_a['w2'].mean()\n",
    "std_all_Euler_W2 = euler_by_a['w2'].std()\n",
    "mean_all_EI_W2 = ei_by_a['w2'].mean()\n",
    "std_all_EI_W2 = ei_by_a['w2'].std()\n",
    "\n",
    "# generation results VPSDE: the a = 0.0 entry of the Euler scheme is the VPSDE baseline\n",
    "VPSDE_mean_KL = mean_all_Euler_KL[0.0]\n",
    "VPSDE_std_KL = std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2[0.0]\n",
    "\n",
    "# exact score results KL / W2 (rows of distances_df with loss == exact_score at the current num_steps)\n",
    "exact_score_mask = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "true_score_KL_Euler = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'euler')]['kl']\n",
    "true_score_KL_EI = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'semii')]['kl']\n",
    "true_score_W2_Euler = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'euler')]['w2']\n",
    "true_score_W2_EI = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'semii')]['w2']\n",
    "\n",
    "# shared plot style for both figures\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "cmap = plt.get_cmap('tab20c')  # cm.get_cmap is deprecated (removed in Matplotlib 3.9)\n",
    "y1_color = cmap(1/20)\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "\n",
    "##### PLOT KL VS BOUND\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "KL_upperbound_values = np.array([x.item() for x in KL_upperbound])\n",
    "ax1.plot(a_values, KL_upperbound_values, color=y1_color, label='Upper-bound', linewidth=3)\n",
    "ax1.fill_between(a_values,\n",
    "                 KL_upperbound_values - KL_upperbound_std,\n",
    "                 KL_upperbound_values + KL_upperbound_std,\n",
    "                 fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel('Values of a')\n",
    "ax1.set_ylabel('Upper bound (KL)', color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "# raw strings: '\\pi', '\\rm', '\\hat' are LaTeX, not Python escape sequences\n",
    "ax2.plot(a_values, mean_all_Euler_KL, color=y2_color, label=r'KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN)', linewidth=3)\n",
    "ax2.plot(a_values, true_score_KL_Euler, color=y2_color, label=r'KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)', linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_KL - std_all_Euler_KL,\n",
    "                 mean_all_Euler_KL + std_all_Euler_KL,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "# float-typed constant line: np.full_like(a_values, ...) would truncate it if a_values were integers\n",
    "ax2.plot(a_values, np.full(len(a_values), VPSDE_mean_KL, dtype=float), color=y3_color, label='VPSDE (NN)', linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_KL - VPSDE_std_KL,\n",
    "                 VPSDE_mean_KL + VPSDE_std_KL,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel('KL divergence', color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "# one legend for both y-axes (plt.legend() alone only lists the twin axis' entries)\n",
    "handles_1, labels_1 = ax1.get_legend_handles_labels()\n",
    "handles_2, labels_2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles_1 + handles_2, labels_1 + labels_2)\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()\n",
    "\n",
    "##### PLOT WASSERSTEIN VS BOUND\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "ax1.plot(a_values, W2_upperbound, color=y1_color, label='Upper-bound', linewidth=3)\n",
    "ax1.set_xlabel('Values of a')\n",
    "ax1.set_ylabel('Upper bound (W2)', color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label=r'W2($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN)', linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label=r'W2($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)', linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2 - std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2 + std_all_Euler_W2,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "ax2.plot(a_values, np.full(len(a_values), VPSDE_mean_W2, dtype=float), color=y3_color, label='VPSDE (NN)', linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2 - VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2 + VPSDE_std_W2,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel('W2 distance', color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "#plt.legend()\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67bca5ea-b9f5-417c-8a32-e4292d652fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: persist the error tables (uncomment to write them to disk)\n",
    "#file_path = os.path.join('models/gaussian_iso/d50_explicit_pkl', 'errors_df.pkl')\n",
    "#errors_df.to_pickle(file_path)  # errors_df (not simulation_df) belongs in errors_df.pkl\n",
    "#n_epochs = 300\n",
    "#file_path = os.path.join(f'models/gaussian_iso/d{d}_cond_ep{n_epochs}_pkl', 'errors_df_cond_300.pkl')\n",
    "#errors_df_cond_300.to_pickle(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d60c2db-ad25-416f-8291-6aaa0ca9347f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the bound terms (E1, E3 and the W2 bound) for the *_cond_300 models.\n",
    "# Disabled by default (expensive); set the flag to True to rebuild errors_df_cond_300.\n",
    "RECOMPUTE_ERRORS_COND_300 = False\n",
    "\n",
    "if RECOMPUTE_ERRORS_COND_300:\n",
    "    # per-a mean of error_approx_sup_L2, passed as the epsilon argument of the W2 bound\n",
    "    epsilon = simulation_df_cond_300.groupby('a')['error_approx_sup_L2'].mean()\n",
    "    error_results = []\n",
    "\n",
    "    for a in tqdm(a_values):\n",
    "        print(f\"bound for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "        error_mixing_E1 = func.compute_E1(training_distribution, sde)\n",
    "        error_discr_E3 = func.compute_E3(training_distribution, sde, num_steps)\n",
    "        w2_error = func.compute_w2_bound(training_distribution, training_sample, sde, num_steps, epsilon[a])\n",
    "        error_results.append({\n",
    "            \"a\": a,\n",
    "            \"error_mixing_E1\": error_mixing_E1,\n",
    "            \"error_discr_E3\": error_discr_E3,\n",
    "            \"w2_error\": w2_error,\n",
    "        })\n",
    "\n",
    "    errors_df_cond_300 = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d0371ea-30f4-43ae-9574-a85f8e838384",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "by_a = results_df.groupby('a')\n",
    "\n",
    "# computation of the KL upper bound: E1 + E2 + E3\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()   # mean over replicates\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "\n",
    "# computation of the W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# generation results KL / W2: filter each sampling scheme once and reuse the groupby\n",
    "euler_by_a = results_df[results_df['scheme'] == 'euler'].groupby('a')\n",
    "ei_by_a = results_df[results_df['scheme'] == 'semii'].groupby('a')\n",
    "mean_all_Euler_KL = euler_by_a['kl'].mean()\n",
    "std_all_Euler_KL = euler_by_a['kl'].std()\n",
    "mean_all_EI_KL = ei_by_a['kl'].mean()\n",
    "std_all_EI_KL = ei_by_a['kl'].std()\n",
    "mean_all_Euler_W2 = euler_by_a['w2'].mean()\n",
    "std_all_Euler_W2 = euler_by_a['w2'].std()\n",
    "mean_all_EI_W2 = ei_by_a['w2'].mean()\n",
    "std_all_EI_W2 = ei_by_a['w2'].std()\n",
    "\n",
    "# generation results VPSDE: the a = 0.0 entry of the Euler scheme is the VPSDE baseline\n",
    "VPSDE_mean_KL = mean_all_Euler_KL[0.0]\n",
    "VPSDE_std_KL = std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2[0.0]\n",
    "\n",
    "# exact score results KL / W2 (rows of distances_df with loss == exact_score at the current num_steps)\n",
    "exact_score_mask = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "true_score_KL_Euler = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'euler')]['kl']\n",
    "true_score_KL_EI = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'semii')]['kl']\n",
    "true_score_W2_Euler = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'euler')]['w2']\n",
    "true_score_W2_EI = distances_df[exact_score_mask & (distances_df[\"scheme\"] == 'semii')]['w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae02ee92-ff8e-45f2-9604-a242b1fd5ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT KL VS BOUND\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "cmap = plt.get_cmap('tab20c')  # cm.get_cmap is deprecated (removed in Matplotlib 3.9)\n",
    "y1_color = cmap(1/20)\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "# upper bound with a +/- 1 std band of the E2 term across replicates\n",
    "KL_upperbound_values = np.array([x.item() for x in KL_upperbound])\n",
    "KL_upperbound_std = results_df.groupby('a')['error_approx_E2'].std().to_numpy()\n",
    "ax1.plot(a_values, KL_upperbound_values, color=y1_color, label='Upper-bound', linewidth=3)\n",
    "ax1.fill_between(a_values,\n",
    "                 KL_upperbound_values - KL_upperbound_std,\n",
    "                 KL_upperbound_values + KL_upperbound_std,\n",
    "                 fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel('Values of a')\n",
    "ax1.set_ylabel('Upper bound (KL)', color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "# raw strings: '\\pi', '\\rm', '\\hat' are LaTeX, not Python escape sequences\n",
    "ax2.plot(a_values, mean_all_Euler_KL, color=y2_color, label=r'KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN)', linewidth=3)\n",
    "ax2.plot(a_values, true_score_KL_Euler, color=y2_color, label=r'KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)', linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_KL - std_all_Euler_KL,\n",
    "                 mean_all_Euler_KL + std_all_Euler_KL,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "# float-typed constant line: np.full_like(a_values, ...) would truncate it if a_values were integers\n",
    "ax2.plot(a_values, np.full(len(a_values), VPSDE_mean_KL, dtype=float), color=y3_color, label='VPSDE (NN)', linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_KL - VPSDE_std_KL,\n",
    "                 VPSDE_mean_KL + VPSDE_std_KL,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel('KL divergence', color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "# one legend for both y-axes (plt.legend() alone only lists the twin axis' entries)\n",
    "handles_1, labels_1 = ax1.get_legend_handles_labels()\n",
    "handles_2, labels_2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles_1 + handles_2, labels_1 + labels_2)\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b5913ff-6aeb-4aec-84cd-bd7b5119dbea",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN VS BOUND\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "cmap = plt.get_cmap('tab20c')  # cm.get_cmap is deprecated (removed in Matplotlib 3.9)\n",
    "y1_color = cmap(1/20)\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "ax1.plot(a_values, W2_upperbound, color=y1_color, label='Upper-bound', linewidth=3)\n",
    "ax1.set_xlabel('Values of a')\n",
    "ax1.set_ylabel('Upper bound (W2)', color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "ax2 = ax1.twinx()\n",
    "# these curves are Wasserstein-2 distances, so they are labelled W2 (not KL)\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label=r'W2($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN)', linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label=r'W2($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)', linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2 - std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2 + std_all_Euler_W2,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "# float-typed constant line: np.full_like(a_values, ...) would truncate it if a_values were integers\n",
    "ax2.plot(a_values, np.full(len(a_values), VPSDE_mean_W2, dtype=float), color=y3_color, label='VPSDE (NN)', linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2 - VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2 + VPSDE_std_W2,\n",
    "                 fc=y3_color, alpha=.2)\n",
    "ax2.set_ylabel('W2 distance', color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "#plt.legend()\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c279c6-e23f-469a-853e-00f6c3f41f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# to train a = 2 onwards\n",
    "\n",
    "n_epochs = 500\n",
    "model_dir = f'models/gaussian_iso/d{d}_cond_ep{n_epochs}'\n",
    "os.makedirs(model_dir, exist_ok=True)  # torch.save fails if the target directory does not exist\n",
    "\n",
    "score_theta_trained = [[] for _ in a_values]   # score_theta_trained[k][j]: replicate j for a_values[k]\n",
    "for k, a in enumerate(a_values):\n",
    "    print(f\"optimization for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "        loss = diff.loss_conditional(network, sde) # to choose when the score function is not analytically available\n",
    "        #loss = diff.loss_explicit(network, sde, diff.explicit_score(sde, training_distribution)) # to choose in the Gaussian case for training with the analytical score function\n",
    "        diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "        score_theta_trained[k].append(network)\n",
    "        torch.save(network.state_dict(), f'{model_dir}/model_{k}_{j}_d{d}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5376ab9a-7caa-402b-87de-b144e06a4870",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# For every trained network: compute the E2 approximation error, sample with both schemes\n",
    "# and record KL / W2 between training_distribution and the empirical sample distribution.\n",
    "# Disabled by default (expensive); set the flag to True to rebuild simulation_df.\n",
    "RUN_GAUSSIAN_SIMULATION = False\n",
    "\n",
    "if RUN_GAUSSIAN_SIMULATION:\n",
    "    schemes = [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]\n",
    "    simulation_results = []\n",
    "\n",
    "    for k, a in enumerate(tqdm(a_values)):\n",
    "        print(f\"simulation and M for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "        exp_score = diff.explicit_score(sde, training_distribution)\n",
    "\n",
    "        for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "            error_approx_E2, error_approx_sup_L2 = func.compute_E2(training_distribution, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "            for scheme_name, scheme_sampler in schemes:\n",
    "                sample = scheme_sampler(init, sde, score_theta, num_steps)\n",
    "                sample_distribution = empirical(sample)\n",
    "                simulation_results.append({\n",
    "                    \"a\": a,\n",
    "                    \"replication\": j,\n",
    "                    \"scheme\": scheme_name,\n",
    "                    \"error_approx_E2\": error_approx_E2,\n",
    "                    \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                    \"kl\": kl(training_distribution, sample_distribution),\n",
    "                    \"w2\": w2(training_distribution, sample_distribution),\n",
    "                })\n",
    "\n",
    "    simulation_df = pd.DataFrame(simulation_results)\n",
    "\n",
    "#save\n",
    "#file_path = os.path.join(f'models/gaussian_iso/d{d}_cond_ep{n_epochs}_pkl', 'simulation_df.pkl')\n",
    "#simulation_df.to_pickle(file_path)\n",
    "#load\n",
    "#file_path_load = 'models/gaussian_iso/d50_explicit_pkl/simulation_df.pkl'\n",
    "#simulation_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "166ac3d2-88f0-4cce-a7d3-232b25b52290",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: persist simulation_df under models/gaussian_iso/d{d}_cond_ep{n_epochs}_pkl (uncomment to write it)\n",
    "#file_path = os.path.join(f'models/gaussian_iso/d{d}_cond_ep{n_epochs}_pkl', 'simulation_df.pkl')\n",
    "#simulation_df.to_pickle(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a28927a7-52e7-4c1e-8929-c0ec39d55dfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### TO LOAD PRETRAINED MODELS\n",
    "# Rebuild every network and load the weights written by the training cell (same d / n_epochs path).\n",
    "model_dir = f'models/gaussian_iso/d{d}_cond_ep{n_epochs}'\n",
    "score_theta_trained = [[] for _ in a_values]   # score_theta_trained[k][j]: replicate j for a_values[k]\n",
    "for k in range(len(a_values)):\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)  # same device as in training\n",
    "        # map_location lets checkpoints saved from a GPU run be loaded on a CPU-only machine\n",
    "        state_dict = torch.load(f'{model_dir}/model_{k}_{j}_d{d}.pt', map_location=device)\n",
    "        network.load_state_dict(state_dict)\n",
    "        score_theta_trained[k].append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8887406-36be-419e-b63b-e4d1629c4451",
   "metadata": {},
   "outputs": [],
   "source": [
    "(n_epochs)  # epoch count used to build the model paths in the training / loading cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e68c96d0-62ff-4e41-8395-ec6ad5ef6cd4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0008282-f903-4e3a-88c6-ba05e51f9baa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f06bbb1-bf12-4db7-91df-d701a4117e86",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c1d4301-bda4-4248-bafc-bf1de08ffb64",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef55f30b-48aa-4aef-a76c-ffebabe257b3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f93ed3a1-b7e0-49b7-92e9-da4ed2bd0e6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22307627-712f-48b3-b008-a431cbc4690a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a40c8b9-9a42-4f4a-821a-5f8b9713222e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d55f275c-0df2-463d-9c6d-c88ed9067c94",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeb6cad8-e03b-48bf-9f3d-a9c4a90f9459",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11e2781c-d7e8-49f6-9e17-428cc919ce69",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4d2864-cac4-4aed-8e24-92c611905a73",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e3cc434-e374-4e03-8e23-7ff56f3dbfa2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f273c08e-25da-4438-8685-003afb0caa15",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af8928c-2c22-4ea8-88da-3bcea9f84a81",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f41a74-0848-489b-b236-56c86edc9c86",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed06ee60-4f42-4a4f-b89c-95b3133aa33b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0464926f-2fca-46fa-9443-2f299e88a8b3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bd2dbd4-c132-4633-8f36-48b0860689c3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fab3be1-e06b-4756-9d3e-858b2e4e31b9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c998bacc-94fc-40be-8da3-d00f329a3a41",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cab89999-2cda-45c0-915b-63c28cef6510",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e2bcbf-bbc3-448b-9fdf-0011c1fcd83e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ac181f35-49e0-4f4e-8c2b-61eedc123baa",
   "metadata": {},
   "source": [
    "## Anisotropic Explicit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48d172b5-c4f0-4ce0-8b8a-0fb8ba82d3c4",
   "metadata": {},
   "source": [
    "### ANISOTROPIC EXPLICIT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ded414d-372f-44e9-a9fa-12ba86a6f9b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET Training data PARAMETERS #################\n",
    "n = 10000  # number of training points\n",
    "d = 50     # ambient dimension\n",
    "mu = torch.ones(d, device = device)\n",
    "\n",
    "### Heteroscedastic case: diagonal covariance, one variance per block of d // len(blocks) coordinates\n",
    "blocks = torch.tensor([1., .01, 1., 1., 1.], device=device)\n",
    "# repeat_interleave would silently produce fewer than d entries otherwise\n",
    "assert d % len(blocks) == 0, f\"d={d} must be a multiple of the number of blocks ({len(blocks)})\"\n",
    "diag_values = torch.repeat_interleave(blocks, d // len(blocks))\n",
    "SIGMA = torch.diag(diag_values)\n",
    "\n",
    "training_distribution = func.gaussian(d, mu, SIGMA)\n",
    "training_sample = training_distribution.generate_sample(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "897c453b-d795-4672-bd36-d2af7958b477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precompute the L0 / C0 quantities of the Gaussian target (presumably consumed by the bound helpers in func; verify)\n",
    "training_distribution.compute_L0()\n",
    "training_distribution.compute_C0()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff9506a7-28d8-4f53-ae84-376c0b892120",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET DIFFUSION AND NOISING PARAMETERS #################\n",
    "# forward VPSDE settings\n",
    "T = 1\n",
    "sigma_inv = 1.\n",
    "beta_min, beta_max = 0.1, 20\n",
    "\n",
    "# parametric noise schedule: `a` is the starting shape parameter, `a_values` the grid swept in the loops below\n",
    "a = torch.tensor(0.)\n",
    "a_values = np.linspace(-10., 10., 21)\n",
    "beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "# alternative schedule: beta = func.beta_cosine(T, 0.003)\n",
    "\n",
    "sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5931cb56-d576-48e0-a804-096d9145aca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################# SET SAMPLING PARAMETERS #################################\n",
    "sample_batch_size = 10000  # size of the sample generated\n",
    "num_steps = 500            # discretisation steps of the backward sampler\n",
    "num_steps_vec = np.array([num_steps // 4, num_steps // 2, num_steps])  # step counts compared in the exact-score study; adjust as desired\n",
    "\n",
    "# NOTE: `init` is disabled here; it must be defined before re-enabling the simulation cells that call the samplers.\n",
    "#init = sde.final.generate_sample(sample_batch_size)\n",
    "#xbarT_euler = sp.Euler_Maruyama_discr_sampler(init, sde, diff.explicit_score(sde, training_distribution), num_steps)\n",
    "#xbarT_semii = sp.EI_discr_sampler(init, sde, diff.explicit_score(sde, training_distribution), num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ae34d3c-0279-4df2-a4ed-65bebad82c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### EXACT SCORE SIMULATION GAUSSIAN CASE ####\n",
    "# Expensive: results are cached in distances_df.pkl and loaded in the next cell.\n",
    "# Set to True to recompute (requires `init` from the sampling-parameters cell).\n",
    "RUN_EXACT_SCORE_SIM = False\n",
    "\n",
    "if RUN_EXACT_SCORE_SIM:\n",
    "    distances = []\n",
    "    schemes = [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]\n",
    "    for ns in tqdm(num_steps_vec):\n",
    "        for a in a_values:\n",
    "            sde.beta.change_a(a)\n",
    "            score_theta = diff.explicit_score(sde, training_distribution)\n",
    "            for scheme_name, scheme in schemes:\n",
    "                # build the empirical distribution once and reuse it for both metrics\n",
    "                sample_emp = empirical(scheme(init, sde, score_theta, ns))\n",
    "                distances.append({\n",
    "                    \"a\": a, \"num_steps\": ns,\n",
    "                    \"scheme\": scheme_name,\n",
    "                    \"loss\": \"exact_score\",\n",
    "                    \"kl\": kl(training_distribution, sample_emp),\n",
    "                    \"w2\": w2(training_distribution, sample_emp),\n",
    "                })\n",
    "    distances_df = pd.DataFrame(distances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc31ba9-f547-4ac7-9d83-227f09009b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cached exact-score results. The save and load lines share one path, so a re-save is read back.\n",
    "file_path_load = os.path.join('models', 'gaussian_aniso', f'd{d}_explicit_pkl', 'distances_df.pkl')\n",
    "#save (only after recomputing distances_df)\n",
    "#distances_df.to_pickle(file_path_load)\n",
    "#load (trusted, locally produced pickle)\n",
    "distances_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b22af410-80ee-4b79-8ffd-2e2657d4e108",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(figsize=(8,4), ncols=2, sharey=True, sharex = True)\n",
    "axs[0].set_title(\"KL divergence, exact score, Euler scheme\")\n",
    "axs[1].set_title(\"KL divergence, exact score, Semi-Implicit scheme\")\n",
    "\n",
    "for k, scheme in enumerate([\"euler\", \"semii\"]):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"]==ns)\n",
    "        res_kl = distances_df[select].groupby(\"a\")[\"kl\"].mean()\n",
    "        # plot against the groupby index so x and y stay aligned even if the cached grid differs from a_values\n",
    "        axs[k].plot(res_kl.index, res_kl.to_numpy(), label=f\"num_steps = {ns}\")\n",
    "    axs[k].legend()\n",
    "    axs[k].set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"KL divergence\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "122b9953-d3ab-4efc-a604-ce2d62f3864c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(figsize=(8,4), ncols=2, sharey=True, sharex = True)\n",
    "axs[0].set_title(\"W2 distance, exact score, Euler scheme\")\n",
    "axs[1].set_title(\"W2 distance, exact score, Semi-Implicit scheme\")\n",
    "\n",
    "for k, scheme in enumerate([\"euler\", \"semii\"]):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"]==ns)\n",
    "        res_w2 = distances_df[select].groupby(\"a\")[\"w2\"].mean()\n",
    "        # plot against the groupby index so x and y stay aligned even if the cached grid differs from a_values\n",
    "        axs[k].plot(res_w2.index, res_w2.to_numpy(), label=f\"num_steps = {ns}\")\n",
    "    axs[k].legend()\n",
    "    axs[k].set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"W2 distance\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cc0c0fe-efd9-4bfe-8298-f9a13edf5faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "######## SELECT TRAINING PARAMETERS W.R.T THE DIMENSION ######################\n",
    "# optimisation parameters per dimension: d -> (n_epochs, learning_rate)\n",
    "optim_params_by_dim = {\n",
    "    5: (2, 1.0e-4),\n",
    "    10: (30, 1.0e-4),\n",
    "    25: (50, 1.0e-3),\n",
    "    50: (150, 1.0e-3),\n",
    "}\n",
    "# fail here rather than with a NameError on n_epochs in a later cell\n",
    "if d not in optim_params_by_dim:\n",
    "    raise ValueError(f\"no training parameters configured for d={d}; expected one of {sorted(optim_params_by_dim)}\")\n",
    "n_epochs, learning_rate = optim_params_by_dim[d]\n",
    "\n",
    "# network_parameters\n",
    "mid_features = 256\n",
    "num_layers = 3\n",
    "batch_size = 64\n",
    "\n",
    "# Monte Carlo estimate samples (for E2 computation in the KL bound)\n",
    "num_mc = 500\n",
    "\n",
    "# number of runs (for replication)\n",
    "rep_size = 10\n",
    "\n",
    "# load data\n",
    "dataloader = DataLoader(training_sample, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# number of discretisation steps (note: the sampling and bound cells read `num_steps`, set in the sampling-parameters cell)\n",
    "num_step = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8efb430c-0658-4e31-8e8d-cbecaa55a12c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "# Training is expensive, so by default the pretrained weights are loaded from `model_dir`.\n",
    "# Set TRAIN_MODELS = True to retrain; uncomment the torch.save line to refresh the checkpoints\n",
    "# (it writes to the same path the loader reads from).\n",
    "TRAIN_MODELS = False\n",
    "model_dir = os.path.join('models', 'gaussian_aniso', f'd{d}_explicit_ep{n_epochs}')\n",
    "\n",
    "score_theta_trained = [ [] for _ in a_values ]   # a_values, replicates\n",
    "for k, a in enumerate(a_values):\n",
    "    if TRAIN_MODELS:\n",
    "        print(f\"optimization for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        ckpt_path = os.path.join(model_dir, f'model_{k}_{j}_d{d}.pt')\n",
    "        if TRAIN_MODELS:\n",
    "            optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "            loss = diff.loss_conditional(network, sde) # to choose when the score function is not analytically available\n",
    "            #loss = diff.loss_explicit(network, sde, diff.explicit_score(sde, training_distribution)) # to choose in the Gaussian case for training with the analytical score function\n",
    "            diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "            #torch.save(network.state_dict(), ckpt_path)\n",
    "        else:\n",
    "            # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine (and vice versa)\n",
    "            network.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
    "        score_theta_trained[k].append(network)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ad14edc-1770-4306-befd-8b4e3d7e4f9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# Expensive: results are cached in simulation_df.pkl and loaded in the next cell.\n",
    "# Set to True to recompute (requires `init` from the sampling-parameters cell).\n",
    "RUN_NN_SIMULATION = False\n",
    "\n",
    "if RUN_NN_SIMULATION:\n",
    "    simulation_results = []\n",
    "    schemes = [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]\n",
    "    for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "        print(f\"simulation and M for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "        exp_score = diff.explicit_score(sde, training_distribution)\n",
    "\n",
    "        for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "            # E2 does not depend on the sampling scheme: computed once per replicate and stored on both scheme rows\n",
    "            error_approx_E2, error_approx_sup_L2 = func.compute_E2(training_distribution, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "            for scheme_name, scheme_sampler in schemes:\n",
    "                # build the empirical distribution once and reuse it for both metrics\n",
    "                sample_emp = empirical(scheme_sampler(init, sde, score_theta, num_steps))\n",
    "                simulation_results.append({\n",
    "                    \"a\": a,\n",
    "                    \"replication\": j,\n",
    "                    \"scheme\": scheme_name,\n",
    "                    \"error_approx_E2\": error_approx_E2,\n",
    "                    \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                    \"kl\": kl(training_distribution, sample_emp),\n",
    "                    \"w2\": w2(training_distribution, sample_emp),\n",
    "                })\n",
    "\n",
    "    simulation_df = pd.DataFrame(simulation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39d14014-8d2e-421a-be57-5da05e15e010",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cached NN-simulation results. The save and load lines share one path (derived from d, like the model directory).\n",
    "file_path_load = os.path.join('models', 'gaussian_aniso', f'd{d}_explicit_pkl', 'simulation_df.pkl')\n",
    "#save (only after recomputing simulation_df)\n",
    "#simulation_df.to_pickle(file_path_load)\n",
    "#load (trusted, locally produced pickle)\n",
    "simulation_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ae33f2-9aa0-4bf3-b1eb-1d09fa3cc23b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expensive: the bound terms are cached in errors_df.pkl and loaded in the next cell.\n",
    "# Set to True to recompute them.\n",
    "COMPUTE_BOUNDS = False\n",
    "\n",
    "if COMPUTE_BOUNDS:\n",
    "    # sup-L2 score error per a, averaged over replicates; fed to the W2 bound below\n",
    "    epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "    error_results = []\n",
    "    for a in tqdm(a_values):\n",
    "        print(f\"bound for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "        error_results.append({\n",
    "            \"a\": a,\n",
    "            \"error_mixing_E1\": func.compute_E1(training_distribution, sde),\n",
    "            \"error_discr_E3\": func.compute_E3(training_distribution, sde, num_steps),\n",
    "            \"w2_error\": func.compute_w2_bound(training_distribution, training_sample, sde, num_steps, epsilon[a]),\n",
    "        })\n",
    "    errors_df = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e688fc94-d895-47de-949a-fe22aa6a4d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cached bound terms (E1, E3, W2 bound). The save and load lines share one path (derived from d).\n",
    "file_path_load = os.path.join('models', 'gaussian_aniso', f'd{d}_explicit_pkl', 'errors_df.pkl')\n",
    "#save (only after recomputing errors_df)\n",
    "#errors_df.to_pickle(file_path_load)\n",
    "#load (trusted, locally produced pickle)\n",
    "errors_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7274e60e-704a-477a-b724-7c6c8e80deaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "by_a = results_df.groupby('a')\n",
    "\n",
    "# computation of the KL upper bound: E1 (mixing) + E2 (score approximation) + E3 (discretisation)\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "\n",
    "# computation of the W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# per-scheme generation results, grouped by a\n",
    "by_a_euler = results_df.loc[results_df['scheme'] == 'euler'].groupby('a')\n",
    "by_a_semii = results_df.loc[results_df['scheme'] == 'semii'].groupby('a')\n",
    "mean_all_Euler_KL, std_all_Euler_KL = by_a_euler['kl'].mean(), by_a_euler['kl'].std()\n",
    "mean_all_EI_KL, std_all_EI_KL = by_a_semii['kl'].mean(), by_a_semii['kl'].std()\n",
    "mean_all_Euler_W2, std_all_Euler_W2 = by_a_euler['w2'].mean(), by_a_euler['w2'].std()\n",
    "mean_all_EI_W2, std_all_EI_W2 = by_a_semii['w2'].mean(), by_a_semii['w2'].std()\n",
    "\n",
    "# VPSDE baseline = the a = 0 schedule with the Euler scheme\n",
    "VPSDE_mean_KL, VPSDE_std_KL = mean_all_Euler_KL[0.0], std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2, VPSDE_std_W2 = mean_all_Euler_W2[0.0], std_all_Euler_W2[0.0]\n",
    "\n",
    "# exact-score results at the reference number of steps\n",
    "exact_rows = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "exact_euler = distances_df[exact_rows & (distances_df[\"scheme\"] == 'euler')]\n",
    "exact_semii = distances_df[exact_rows & (distances_df[\"scheme\"] == 'semii')]\n",
    "true_score_KL_Euler, true_score_KL_EI = exact_euler['kl'], exact_semii['kl']\n",
    "true_score_W2_Euler, true_score_W2_EI = exact_euler['w2'], exact_semii['w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff88c107-de49-41e1-ac6a-30fd2219d399",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT KL VS BOUND\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "cmap = mpl.colormaps['tab20c']  # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "# the bound terms may be stored as tensors / numpy scalars: convert to a plain float array\n",
    "KL_ub = np.array([float(x) for x in KL_upperbound])\n",
    "# error_approx_E2 is computed once per replicate and copied onto both the euler and semii rows,\n",
    "# so measure its spread on one scheme only (otherwise every value is counted twice)\n",
    "E2_std = results_df[results_df['scheme'] == 'euler'].groupby('a')['error_approx_E2'].std().to_numpy()\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "y1_color = cmap(1/20)\n",
    "ax1.plot(a_values, KL_ub, color=y1_color, label=\"Upper-bound\",linewidth=3)\n",
    "ax1.fill_between(a_values, KL_ub - E2_std, KL_ub + E2_std, fc=y1_color,alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (KL)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_KL, color=y2_color, label=r\"KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN) \",linewidth=3)\n",
    "ax2.plot(a_values, true_score_KL_Euler, color=y2_color, label=r\"KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\",linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_KL-std_all_Euler_KL,\n",
    "                 mean_all_Euler_KL+std_all_Euler_KL,\n",
    "                 fc=y2_color,alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_KL), color=y3_color, label=\"VPSDE (NN)\",linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_KL-VPSDE_std_KL,\n",
    "                 VPSDE_mean_KL+VPSDE_std_KL,\n",
    "                 fc=y3_color,alpha=.2)\n",
    "ax2.set_ylabel(\"KL divergence\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "# plt.legend() would only list ax2's lines: merge the handles of both twin axes so the bound appears too\n",
    "handles1, labels1 = ax1.get_legend_handles_labels()\n",
    "handles2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles1 + handles2, labels1 + labels2)\n",
    "fig.tight_layout()\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346b5e0d-e1c8-42af-8149-6eb57267216c",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN ########\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "cmap = mpl.colormaps['tab20c']  # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "# the bound values may be stored as tensors / numpy scalars: convert to a plain float array\n",
    "W2_ub = np.array([float(x) for x in W2_upperbound])\n",
    "# error_approx_sup_L2 is computed once per replicate and copied onto both the euler and semii rows,\n",
    "# so measure its spread on one scheme only (otherwise every value is counted twice)\n",
    "supL2_std = simulation_df[simulation_df['scheme'] == 'euler'].groupby('a')['error_approx_sup_L2'].std().to_numpy()\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "y1_color = cmap(1/20)\n",
    "ax1.plot(a_values, W2_ub, color=y1_color, label=\"Upper-bound\",linewidth=3)\n",
    "ax1.fill_between(a_values, W2_ub - supL2_std, W2_ub + supL2_std, fc=y1_color,alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (W2)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label=r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \",linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label=r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (exact score)\",linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                 fc=y2_color,alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y3_color, label=\"VPSDE (NN)\",linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                 fc=y3_color,alpha=.2)\n",
    "ax2.set_ylabel(r\"$\\mathcal{W}_2$ distance\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "# plt.legend() would only list ax2's lines: merge the handles of both twin axes so the bound appears too\n",
    "handles1, labels1 = ax1.get_legend_handles_labels()\n",
    "handles2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles1 + handles2, labels1 + labels2, loc = 'upper center')\n",
    "fig.tight_layout()\n",
    "#plt.savefig(\"anisotropic_num_steps_500_W_2_C_t+L_t_corr_M1.pdf\")\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5046fb61-aad0-4ff8-99dc-b137e8c3746e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe37b59a-d768-4747-b3f6-3fd943eea4da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53915e37-1c8c-4340-802f-263adaa42eab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd5ef7e6-46d0-41ca-9499-b06b00b261b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afc06058-39b1-4c8e-89be-016d4eef416e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bc158a1-00cc-4df9-9360-f17582505e57",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0414d492-8a54-444f-b7fc-3a51ceb2b47e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be890b6-118e-40fe-a916-8f64a46eb890",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ef059f-27ae-4369-9540-7a52a7435f0d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50a0d75-dc44-4b7d-b189-8ddca524dae1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d73bf559-7235-4839-a67e-a8c039444988",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad6c1d84-4def-4947-8515-39bea40db8ef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "409d9ceb-2f63-48e3-9b52-711d7131fd18",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc0acf10-6a34-4cd7-80cf-9106e17f1e08",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebfc8dc-494e-4239-be87-20c6106650a8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a71d0773-7ab5-4491-b9b8-cd1ebe39813b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66849ce6-3a1f-4e47-bf10-34949a356902",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b01c3b21-4221-4cde-964e-cbe8852588de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f221ee66-dff7-405e-8d61-af39f1c42f2c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b58493-72ca-48bd-9660-8adcda27c90e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f230f1-a1b0-480d-8e93-1641843bbbc6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5120b1ca-5a72-4f4b-96d6-43fb41ba0333",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5e0284a-c4cd-4c2a-ae42-b20e911bb7b5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf6043e-599a-47f2-8c2e-46e90c349d41",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35f38341-9f57-4635-9aa4-6d88d9ceeafd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5806fbe-03de-4e26-923c-3d3cb400da85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b846996d-f44e-4779-b7b5-ffba493b7648",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "763b84ed-4936-4b6a-8f28-aac231933b97",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79507b4c-1d01-43f5-ae99-8b25af7e9593",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a981f0-00fd-43f5-8a49-2754db455096",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9da01558-4d1d-4e4b-9f2b-92a852664f7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63406941-8403-4435-a9b7-66ec08e51f64",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ca62838-9446-44e9-b314-f124c34c524a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "311dbbb2-95bd-4238-b3f4-ef09ba484fa0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f323a2e0-a4ab-4b97-b399-d8b817ac3b1e",
   "metadata": {},
   "source": [
    "### ANISOTROPIC EXPLICIT RESCALE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7d06cb73-3978-4a9c-80e5-0f01a4c0d514",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET Training data PARAMETERS #################\n",
    "n = 10000  # number of training points\n",
    "d = 50     # ambient dimension\n",
    "mu = torch.ones(d, device = device)\n",
    "\n",
    "### Heteroscedastic case: diagonal covariance, one variance per block of d // len(blocks) coordinates\n",
    "blocks = torch.tensor([1., .01, 1., 1., 1.], device=device) \n",
    "# repeat_interleave would silently produce fewer than d entries otherwise\n",
    "assert d % len(blocks) == 0, f\"d={d} must be a multiple of the number of blocks ({len(blocks)})\"\n",
    "diag_values = torch.repeat_interleave(blocks, d // len(blocks))\n",
    "SIGMA = torch.diag(diag_values)\n",
    "\n",
    "training_distribution = func.gaussian(d, mu, SIGMA)\n",
    "training_sample = training_distribution.generate_sample(n)\n",
    "\n",
    "rescale = np.sqrt(2) #set diag to 1/rescale**2\n",
    "# `mean` and `std` are kept to map generated samples back with func.unnormalize\n",
    "training_sample_rescale, mean ,std  = func.normalize(training_sample, rescale)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b02ed78-895c-4887-9db9-9905014f0c64",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Check eigenvalues of the rescaled version\n",
    "CHECK_RESCALED_SPECTRUM = False  # set to True to print the extreme eigenvalues\n",
    "if CHECK_RESCALED_SPECTRUM:\n",
    "    cov = func.compute_cov_matrix(training_sample_rescale)\n",
    "    # a covariance matrix is symmetric (assumed for compute_cov_matrix's output): eigvalsh returns real eigenvalues\n",
    "    eigenvalues = torch.linalg.eigvalsh(cov)\n",
    "    print(torch.min(torch.abs(eigenvalues)), torch.max(torch.abs(eigenvalues)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "506b2efa-d1b3-48dc-aad0-751dd69c1367",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Check the unnormalize function: presumably it inverts func.normalize (verify)\n",
    "Z = func.unnormalize(training_sample_rescale,mean, std, rescale)\n",
    "# max round-trip error; expected to be at float round-off level if unnormalize inverts normalize\n",
    "# (meaningful only if func.normalize does not modify training_sample in place)\n",
    "(Z - training_sample).abs().max().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d67d0917-db51-4f83-9ef7-9851a15e728d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Actual distribution of the rescaled version: N(0, I / rescale**2)\n",
    "# built on `device`, the same device the sde is created on\n",
    "cov_scale = torch.eye(d, device=device)/(rescale)**2 #func.compute_cov_matrix(training_sample_rescale)\n",
    "mean_scale = torch.zeros(d, device=device) #torch.mean(training_sample_rescale, dim=0)\n",
    "dist =  gaussian(d,mean_scale, cov_scale)\n",
    "\n",
    "### Empirical distribution of the rescaled version\n",
    "\n",
    "#cov_emp = func.compute_cov_matrix(training_sample_rescale)\n",
    "#mean_emp = torch.mean(training_sample_rescale, dim=0) \n",
    "#dist_emp = gaussian(d,mean_emp, cov_emp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "38868926-ac85-4c49-b3da-6786dec482a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET DIFFUSION AND NOISING PARAMETERS #################\n",
    "T = 1\n",
    "sigma_inv = 1.\n",
    "beta_min, beta_max = 0.1, 20\n",
    "\n",
    "# noise-schedule shape parameter: `a` is the starting value, `a_values` the grid swept below\n",
    "a = torch.tensor(0.)\n",
    "a_values = np.linspace(-10., 10., 21)\n",
    "\n",
    "beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a196e5c5-4e14-4df8-98c6-799a13d0fed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################# SET SAMPLING PARAMETERS #################################\n",
    "num_steps = 500  # discretisation steps of the backward sampler\n",
    "sample_batch_size = 10000 # size of the sample generated\n",
    "# shared starting points for every sampler run below, drawn from the sde's final distribution\n",
    "init = sde.final.generate_sample(sample_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1acee2b6-e442-4989-8ca0-51ec61258916",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [14:09, 849.43s/it]\n"
     ]
    }
   ],
   "source": [
    "##### EXACT SCORE SIMULATION GAUSSIAN CASE ####\n",
    "num_steps_vec = np.array([num_steps]) #to adjust as desired\n",
    "# only the Euler scheme is run here; append (\"semii\", sp.EI_discr_sampler) to also run the semi-implicit one\n",
    "samplers = [(\"euler\", sp.Euler_Maruyama_discr_sampler)]\n",
    "distances = []\n",
    "for _, ns in tqdm(enumerate(num_steps_vec)):\n",
    "    for a in a_values:\n",
    "        sde.beta.change_a(a)\n",
    "        # exact score of the rescaled target `dist`: sampling happens in normalized coordinates\n",
    "        score_theta = diff.explicit_score(sde, dist)\n",
    "        for scheme_label, sampler_fn in samplers:\n",
    "            # map generated points back to the original data scale before comparing with the target\n",
    "            sample = func.unnormalize(sampler_fn(init, sde, score_theta, ns), mean, std, rescale)\n",
    "            distances.append({\n",
    "                \"a\": a,\n",
    "                \"num_steps\": ns,\n",
    "                \"scheme\": scheme_label,\n",
    "                \"loss\": \"exact_score\",\n",
    "                #\"kl\": kl(dist, empirical(sample)),\n",
    "                \"w2\": w2(training_distribution, empirical(sample)),\n",
    "            })\n",
    "distances_df = pd.DataFrame(distances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2f5a9001-88c1-4520-b231-e8763f81b82a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>a</th>\n",
       "      <th>num_steps</th>\n",
       "      <th>scheme</th>\n",
       "      <th>loss</th>\n",
       "      <th>w2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-10.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.230114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-9.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.230397</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-8.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.227390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-7.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.222234</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-6.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.225810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>-5.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.238143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>-4.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.238609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>-3.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.216215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>-2.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.228516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>-1.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.230749</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.232538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.227794</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>2.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.230933</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>3.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.230740</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>4.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.234117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>5.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.228170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>6.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.231055</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>7.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.229593</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>8.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.222680</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>9.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.223080</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>10.0</td>\n",
       "      <td>500</td>\n",
       "      <td>euler</td>\n",
       "      <td>exact_score</td>\n",
       "      <td>0.220014</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       a  num_steps scheme         loss        w2\n",
       "0  -10.0        500  euler  exact_score  0.230114\n",
       "1   -9.0        500  euler  exact_score  0.230397\n",
       "2   -8.0        500  euler  exact_score  0.227390\n",
       "3   -7.0        500  euler  exact_score  0.222234\n",
       "4   -6.0        500  euler  exact_score  0.225810\n",
       "5   -5.0        500  euler  exact_score  0.238143\n",
       "6   -4.0        500  euler  exact_score  0.238609\n",
       "7   -3.0        500  euler  exact_score  0.216215\n",
       "8   -2.0        500  euler  exact_score  0.228516\n",
       "9   -1.0        500  euler  exact_score  0.230749\n",
       "10   0.0        500  euler  exact_score  0.232538\n",
       "11   1.0        500  euler  exact_score  0.227794\n",
       "12   2.0        500  euler  exact_score  0.230933\n",
       "13   3.0        500  euler  exact_score  0.230740\n",
       "14   4.0        500  euler  exact_score  0.234117\n",
       "15   5.0        500  euler  exact_score  0.228170\n",
       "16   6.0        500  euler  exact_score  0.231055\n",
       "17   7.0        500  euler  exact_score  0.229593\n",
       "18   8.0        500  euler  exact_score  0.222680\n",
       "19   9.0        500  euler  exact_score  0.223080\n",
       "20  10.0        500  euler  exact_score  0.220014"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the results table: one row per (a, num_steps, scheme) run with its W2 distance.\n",
    "distances_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ef39f951-3cd0-4c90-aa65-bbe8db8edaa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache of the exact-score W2 results (the simulation cell above is slow).\n",
    "# Save and load share ONE path. Previously the (commented-out) save path was\n",
    "# 'd50_explicit_rescale_ep150_pkl' while the load path was 'd50_explicit_pkl', and the\n",
    "# unconditional load silently replaced the freshly computed `distances_df`.\n",
    "# NOTE(review): confirm this directory matches the configuration being run.\n",
    "SAVE_DISTANCES = False  # set True to (over)write the cache with the current results\n",
    "distances_dir = 'models/gaussian_aniso/d50_explicit_rescale_ep150_pkl'\n",
    "file_path = os.path.join(distances_dir, 'distances_df.pkl')\n",
    "file_path_load = file_path\n",
    "\n",
    "if SAVE_DISTANCES:\n",
    "    os.makedirs(distances_dir, exist_ok=True)\n",
    "    distances_df.to_pickle(file_path)\n",
    "elif 'distances_df' not in globals():\n",
    "    # Only reload when the simulation cell was skipped. read_pickle unpickles,\n",
    "    # so load trusted files only.\n",
    "    distances_df = pd.read_pickle(file_path_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "62e7f941-5100-4f84-b9f7-3292a8ebf546",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxAAAAGACAYAAAA9AISXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAACGRElEQVR4nO3dd3hTZfsH8G9mk+5NC6Vggba0pZTRQtlLEF9cuAX9iSDIkFdUlqDiQF4ZMmUpiCxxoyIqMhSQIVB2KbOMQhc0bbrTJuf3R5pAaEuTNm2S9vu5rl7ak3NOnjxJD+fO89z3IxIEQQAREREREZEZxLZuABEREREROQ4GEEREREREZDYGEEREREREZDYGEEREREREZDYGEEREREREZDYGEEREREREZDYGEEREREREZDYGEEREREREZDYGEFRnuGah9bFPiYiIqK4xgLBT//3vf9GpU6dy20+ePImwsDC0b98eJSUlJo+dOnUKYWFh2Lx5MwAgLy8PH3/8Mfr164eYmBg89NBD2LBhA3Q6ncXtef755/H8888bf+/Tpw+mTJli9vE7duzA5MmTLX5eqlxD6tPFixcjLCzsnj/FxcUWn89e3f33RkREZE+ktm4AVSw+Ph6///47Ll26hJCQEOP2PXv2wNPTE9nZ2Th69Cji4uKMjx0+fBgA0LVrVwiCgNdeew0nT57E+PHjERISgv379+PDDz9EdnY2xo4dW6P2LVmyBK6urmbvv2bNmho9H5XXEPv066+/rvQxuVxehy0hIiJquBhA2Kn4+HgAQEJCgkkAsXfvXjzwwAPYvXs39uzZYxJAHDp0CKGhofDz88Pp06exZ88eLFiwAAMHDjSeMycnB59//jnGjBkDkUhU7fZFRERU+1ii6oqJibF1E4iIiBo8TmGyU82aNUOTJk2QkJBg3Jabm4vjx4+jS5cuiI+Px969e02OOXLkCLp27Wr8/emnnzYGIgYhISEoKCjArVu3Kn3uGzduYNy4cejQoQO6du2KL774otw+d09h2rJlCx5++GFER0ejc+fOePPNN5Geng5APx3j33//xb///ouwsDAcPHgQAJCUlIRx48ahc+fOiIyMRPfu3fHhhx+iqKjIeN6wsDBs2LAB06ZNQ1xcHNq1a4f//ve/uHnzpkl7Nm/ejMceewxt27ZFr169MG/ePGg0GuPj586dw6hRo9C+fXu0b98eY8eOxbVr1yrtg8rodDqsXLkS999/P6KiojBgwACsW7fO+PipU6cQGRlp0je3bt1CfHw8hg0bZsxZOHToEIYPH47Y2FhERUWhT58+WLx4scn0sry8PHzwwQfo3r07YmJi8Pjjj+Ovv/66Z5/eLSsrC2+88Qa6du2KNm3a4JFHHjFOcTO4dOkSxo0bh7i4OMTGxmLUqFG4ePGi8fHc3FzMmjUL/fr1Q5s2bTBo0CB89913Jufo06cPPvroI/zf//0foqOjMW3aNABAdnY23nnnHXTp0gVt2rTBU089hf3791vc75aoaPrPwYMH79lPALB9+3YMHjwYbdq0QdeuXfHhhx+ioKDA+PjixYtx//33Y8mSJYiLi0O3bt2Qk5NT4bm+/PJLPPDAA2jTpg26d++OGTNmIC8vz/i4RqPBggUL0LdvX0RHR2PQoEH48ccfTc4hCAI+++wz9OrVC9HR0Xj66adx4sQJk32q+lwbXvf+/fvx/PPPIzo6Gr169cK3336LjIwMjBs3Du3atUPPnj3LjWjZ4r0jIiIHIZDdmjp1qjBgwADj73/88YfQunVrIScnR/j111+FsLAwITMzUxAEQTh//rwQGhoq7N69+57nHDp0qNC5c2dBq9VW+Hh+fr7Qu3dv4f777xd+/fVX4bfffhMGDhwoREZGCkOHDjXu17t3b2Hy5MmCIAjC4cOHhdatWwuLFy8WDhw4IGzevFno2rWrMGTIEGPbHn30UeHRRx8Vjh49KuTm5grp6elC+/bthZde
eknYtWuX8M8//wizZs0SQkNDhRUrVhifJzQ0VOjQoYMwZcoUYc+ePcLGjRuFNm3aCBMmTDDus379eiE0NFSYNm2asHv3bmHDhg1C27ZthbffflsQBEG4dOmS0K5dO+Hxxx8Xtm3bJmzdulV46KGHhK5duwo3b9605C0R3n77bSEyMlJYtGiRsGfPHuGTTz4RwsPDhSVLlhj3mT9/vhAaGirs27dPEARBGDNmjBAXFyekpaUJgiAIZ86cESIiIoTXX39d2LNnj7B7925h4sSJQmhoqLBlyxZBEAShtLRUePLJJ4XY2Fhh3bp1wj///CO8+eabQkREhHDo0KEK+7QiL730kvDII48If/75p7B//35hypQpQmhoqLB//35BEAQhLS1N6Nixo/Cf//xH+PXXX4Vdu3YJgwcPFrp27SqoVCqhsLBQGDRokBAfHy989dVXwu7du4V33nlHCA0NFZYtW2byeYiIiBDmzJkj7NmzR0hISBCKioqEhx9+WOjSpYvwzTffCH/99Zfw6quvChEREca+MdeiRYuE0NBQoaSkpMKfOz/PQ4cONfmsCoIgHDhwQAgNDRUOHDhgcj6Dn3/+WQgNDRXeeOMN4e+//xY2btwoxMbGCv/3f/8n6HQ64zERERHCE088Iezdu9f4Xt3tl19+ESIjI4W1a9cKBw8eFL766ishJiZGmDRpknGfV199VYiOjhaWLVsm7Nu3z/jZ/+WXX4yvITw8XHjyySeF7du3C7///rvQq1cvoUuXLkJJSYkgCOZ9rg2vu3PnzsLq1auFffv2CS+++KLQunVrYcCAAcKCBQuEffv2CePGjRNCQ0OF48ePC4IgWPW9IyKi+ocBhB0z3NTcunVLEARBmD59uvD0008LgiAIKpVKCA8PF3788UdBEARh48aNQlRUlFBYWFjp+dasWSOEhoYKq1evrnSf9evXC2FhYcL58+eN227cuHHPAGLFihVCu3bthOLiYuPjf/31l7B48WLjzdfdN3V79uwRhgwZUu7Gd9CgQcJLL71k/D00NFR49tlnTfaZMmWKEBMTIwiCIGi1WiE+Pl4YM2aMyT6ff/658NhjjwkajUZ4/fXXhS5dupg8l0qlEjp06CD873//q7Qv7nbp0iUhLCzMJMARBH3A0KZNGyErK0sQBEHQaDTCQw89JAwYMED4/vvvhdDQUOG3334z7v/jjz8KI0aMMLnp1Wq1QocOHYxBz86dO4XQ0FDhzz//NNnn6aefFhYvXiwIQsU3yneLiooyudHXarXC//73P+HIkSOCIAjC//73PyE6OlrIyMgw7pOamir06tVL+Ouvv4QNGzYIoaGhQkJCgsl533rrLaFNmzaCSqUSBEH/eejXr5/JPl9//bUQGhoqHDt2zLhNp9MJQ4YMEQYPHnzPdt/NcMNf2c97771n3NfSAEKn0wk9evQQhg8fbnLMvn37hNDQUGHXrl0mxxw6dOiebX377beFAQMGmLy/P/30k7B27VpBEATh7NmzQmhoqLBmzRqT48aNGydMnz7d+Bqio6ON/SsIgvDNN98IoaGhwpkzZwRBEMz6XBte95w5c4z7HDt2TAgNDRUmTpxo3JaVlSWEhoYKX3zxhSAI1n3viIio/mEOhB0zTD86evQo+vbti7179+Lxxx8HAHh6eiIyMhL79u3Do48+isOHD6N9+/ZQKBQVnmv9+vWYNWsWBg4ciBdffLHS5zx8+DCCg4PRsmVL47bAwMB7zj2PjY3F/PnzMWjQIAwYMAA9e/ZEt27d0LNnz0qP6datG7p164aSkhJcuHABV65cwblz55CVlQVPT0+Tfe9+7oCAABQWFgIAkpOTcevWLdx///0m+wwfPhzDhw8HABw4cABxcXFQKBQoLS0FALi6uqJjx47Yt29fpW2824EDByAIAvr06WM8D6CfvrNs2TIcOXIE/fr1g0wmw8cff4wnn3wS06ZNw2OPPYYHHnjAuP+jjz6KRx99FMXFxUhOTsaVK1dw
5swZaLVaY2WtI0eOQCaToU+fPsbjxGIxNm3aZHZ7AaBTp05YvHgxEhMT0b17d/Ts2dOkctORI0cQExMDPz8/47aAgADs2rULAPDaa6+hSZMmaNeuncl5H374YXz33Xc4fvy48X1u3bq1yT779++Hn58fIiMjTfqrd+/emD17NnJycuDh4WHR67l76pSBj4+PRee506VLl5CWloZRo0aZtDM2Nhaurq74559/0KtXL+P2u1/n3Tp37oyvv/4agwcPRr9+/dCzZ0889NBDxpyjI0eOAAD69+9vctzixYtNfm/ZsqXJ30JQUBAA/ZQywLLP9Z3vn6Gv2rZta9zm5eVlcu7aeO+IiKj+YABhx3x9fREaGoqEhAQ0b94cN27cQPfu3Y2Pd+3a1Tif/ciRI3juuefKnUOn02H27Nn44osvMGjQIHz88cf3TJ7Oyckx3kzcyc/Pr1zegUG7du2wcuVKrFmzBl988QVWrlwJX19fvPLKK5WWotTpdPjkk0+wYcMGFBQUIDAwENHR0XByciq3r1KpNPldLBYbcwmys7MB3PsGMjs7G1u3bsXWrVvLPebt7V3pcRWdBwD+85//VPi4IecD0N9khoWF4dSpU+jdu7fJfkVFRfjggw/w008/obS0FEFBQWjXrh2kUqnJ6/L09IRYXLM0pfnz52P58uX47bff8Mcff0AsFqNLly54//330aRJE2RnZxtvTCuSk5NjElwY+Pr6AgDUarVxm7Ozs8k+2dnZyMzMRGRkZIXnzszMtPgmtE2bNhbtbw7D+/ree+/hvffeK/d4RkaGye8uLi73PN+DDz4InU6HjRs3YunSpVi8eDGaNGmCN998Ew8++KBZn1mgfH8aPguGPBlLPtcVVUy7++/qTrXx3hERUf3BAMLOde7cGcePH0dgYCA8PT1NbqC6deuG5cuX48CBA0hNTTVJoAb0iZpvvPEGtm3bhpdeegmTJk2qsvKSl5cXrly5Um674aanMt27d0f37t1RWFiIAwcOYO3atfjwww/Rtm1bREdHl9vfEHC899576N+/P9zc3AAATzzxxD2f527u7u4A9MnCd1KpVEhMTES7du3g5uaGLl26YNiwYeWOl0rN/xMwPNeXX35Z4U1k48aNjf//9ddf49SpUwgPD8fMmTMRHx9vPH7mzJn4448/sGDBAnTp0sV4o3hnwrubmxuys7MhCILJe5aYmAhBECq9sbubm5sbJk6ciIkTJ+LSpUvYsWMHli5divfeew8rV66Em5tbub4D9N9ABwUFwcPDo8LPQ2ZmJgBUGGze+dzNmzfH3LlzK3z8XoFLTWm1WpPf70yGvpvhfZk0aZJJVTOD6twoDxo0CIMGDUJubi727t2Lzz77DBMnTkSHDh1MPrMBAQHGYy5evIjs7Gx06NDBrOew1ue6snPb6r0jIiL7xypMdq5Lly44ffo0Dh48iPj4eJNvpGNiYuDi4oKNGzfCy8urXGnVqVOn4s8//8TUqVMxefJks8q2du7cGSkpKTh58qRxW1ZWFo4dO1bpMR9//DEef/xxCIIApVKJ3r17G6fJ3LhxAwDKfZN+5MgRtGzZEo8//rgxeEhPT8e5c+csWuguJCQEXl5exik3Bj/99BNGjhyJkpISxMXF4cKFC2jdujXatGmDNm3aICoqCmvWrMGff/5p9nN17NgRgD44MZynTZs2yMrKwsKFC41B1vXr1/Hxxx/jiSeewPLly5Gbm4uZM2eavPZOnTqhX79+xuDh1KlTyMrKMr72jh07oqSkBLt37zYeJwgCpk6dihUrVgAo36d3u379Onr27Inff//d2Fcvv/wyunTpYnxfOnbsiOPHj5sEEbdu3cKIESPw999/IzY2FtevX8fRo0dNzv3zzz9DJpNVGBwaxMXFITU1FT4+Pib99c8//+Dzzz+HRCK5Z/ury9XVFWlpaSbbDNOGKhISEgIfHx+kpKSYtLNRo0aYN28eEhMTLXr+1157zbjOipubGwYOHIgx
Y8agtLQUGRkZxgBh586dJsfNnTvX5HNSFWt9ris7ty3eOyIicgwcgbBzsbGx0Gg02LVrF2bMmGHymEwmQ1xcHHbu3In+/fubBAjbt2/Hli1b0KdPH8TExJQLACIiIipceOuRRx7B2rVrMW7cOEyYMAGurq5YtmzZPW/qO3fujC+++AJTpkzBww8/jJKSEnz++efw9PRE586dAei/5T169Cj279+PiIgIREdHY+nSpVi5ciViYmJw5coVrFixAhqNxpjfYA6JRIJXX30V77//Pnx8fNCnTx8kJydj0aJFGDJkCDw8PDBmzBg888wzGDVqFJ599lk4OTnh66+/xvbt27Fo0SLjuRITEyGXy03yP+4UFhaGhx9+GG+//TauX7+OqKgoJCcnY/78+QgKCkLz5s0hCAKmTZsGpVKJSZMmwcPDA6+99ho++ugjDBgwAH369EF0dDR+++03fPXVV2jRogWSkpKwbNkyiEQi42vv1asX2rVrhylTpuC1115D06ZN8dNPP+HixYv44IMPKuzTu78pb9KkCQICAvDhhx8iLy8PwcHBOHXqFP7++2+MGjUKAPDiiy9i8+bNGDFiBEaNGgWZTIZly5YhICAADz30EORyOTZu3IixY8di/PjxCAoKws6dO/H9999j3Lhxxm/TKzJ48GCsX78ew4YNwyuvvILAwEDs27cPn332GYYOHQqZTAYAuHr1KrKyssxa4+Fegex9990HDw8P9O7dGzt37sSsWbPQp08fHD58uFzp2jtJJBJMmDAB77zzDiQSCXr37g21Wo2lS5ciPT3d7NEeg86dO+Pdd9/Fxx9/jB49ekCtVmPJkiVo3rw5wsPDIZPJ8MADD2DOnDkoKipC69atsXv3buzatQtLliwx+3nM/VxXh7nvHRERNUwMIOycq6sr2rRpg6NHj6Jbt27lHu/evTt27dqFLl26mGzftm0bAP23nHd/0wkAO3bsqHAaglwux5dffomPPvoIM2fOhEgkwlNPPYWmTZtWunZEz549MXfuXKxevRrjxo2DSCRChw4dsHbtWmMS6JAhQ3Dq1Cm8/PLLmDVrFkaNGgWVSoW1a9fi008/RWBgIB555BGIRCKsWLECarX6njendxoyZAicnZ2xatUqfP311wgICMDLL7+Ml19+GQAQHh6ODRs2YP78+Zg0aRIEQUBoaCg+/fRT9O3b13iecePGoUmTJibrOtxt1qxZWLFiBTZt2oS0tDT4+PjgwQcfxGuvvQaJRIINGzZg//79WLBggfGG/vnnn8cvv/yCd955B+3bt8eUKVNQUlKCBQsWQKPRICgoCKNHj8aFCxewc+dOaLVaSCQSfPbZZ5g7dy4WLlyIwsJChIWFYfXq1cZv/e/u04ceeqhce5csWYJPPvkECxcuhEqlQmBgIMaNG4eRI0cC0CfIb9y4EXPmzMGUKVMgl8vRqVMnzJ8/39j+devWYd68eVi4cCHy8vIQEhKCmTNnVjndzNnZGRs2bMC8efMwZ84c5ObmokmTJnjjjTfw0ksvGfdbunQpfvzxR5w9e7bK9/rpp5+u9LFPP/0U/fr1w+OPP46rV6/ixx9/xKZNmxAbG4tFixbh2WefrfTYJ598Ei4uLvj888/x9ddfw9nZGe3bt8fcuXPRtGnTKtt1p2eeeQYlJSXYtGkTNm7cCIVCgfj4eEycONF44z1nzhwsWbIEX375JVQqFVq0aIFFixahX79+Zj+PuZ/r6jD3vSMiooZJJBiyNokauGvXrmHGjBlYtWqVrZvS4PTv398Y9BIREZF9Yw4EUZnly5eXS0Sn2vfTTz8hJCTE1s0gIiIiM3EEgqhMYmIiWrdubVayOVnPpUuXEBAQUK5sKREREdknBhBERERERGQ2TmEiIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzcSE5KxEEATpd9fLRxWJRtY9tqNhnlmF/WY59Zrnq9plYLLJa9TNei+sW+8xy7DPLsL8sVxfX
YgYQVqLTCcjKyrf4OKlUDC8vF6jVBSgt1dVCy+of9pll2F+WY59ZriZ95u3tAonEOgEEr8V1h31mOfaZZdhflqurazGnMBERERERkdkYQBARERERkdkYQBARERERkdkYQBARERERkdkYQBARERERkdlYhYnIgelLVuqg02kr3UenE6GoSAKNphhaLUvhmYN9ZrnK+kwikUIs5ndVRET1CQMIIgckCAIKC/OQl5dzz+DB4OZNMXQ6lsCzBPvMcpX1mVLpCnd3b6ut9UBERLbFAILIAanVWSgszINC4QKFwhliseSeN2cSiYjfpFuIfWa5u/tMEARoNMXIy1MBADw8fGzVNCIisiIGEEQORqfTorAwH66unnB19TDrGKlUzEV4LMQ+s1xFfSaXOwEA8vJUcHPz4nQmIqJ6gFdyIgej1WoBCHByUti6KURmMQQRWm2pjVtCRETWwACCGhRVbjFy8jW2boaVcD45OQbmPhAR1S+cwkT1XmFxKQ6fzcD+U2k4ezUbSicpZo/uAmcFP/5EREREluIdFNVLWp0OiZdV2H8qDQnnMqG5Y152QXEpbuYUIljhZsMWkiO7dOki0tJS0aVLN1s3xWLPPPMYUlKumWwbOHAQpk2bAQDIycnGggVzsX//PxCJROjXbwDGjv0vFIrbU+Z27tyO1atX4MaNG2jWrBnGjn0NHTvG1eXLICIiG2IAQfXK1fRc7DuVhoOJ6SZTlQJ9nBEfGYC/jl1HlroYhcWci03VN3nyBDzwwH8cLoAoLCzEjRvXMXv2AoSFhRu335lPM336ZBQVFWLhwmXIy8vFrFnvo7CwANOnvwcASEg4jPffn46xY19DXFxnbNnyEyZNeg2rV29Ay5Yt6vw1ERFR3WMAQQ5PlVuMg4np2HcqFSmZ+cbtrkoZOkU0QpeoADQPcINIJMLR85n6AEJT9doJRJURBMcs75qcfBE6nQ5RUdFwd3cv9/ipUydw9OgRrF//LZo3vw8AMGnSNLzxxqsYNWos/Pz8sX79GvTo0RtPPvkMAGDs2P/i5Mnj+OabjXjrrbfr9PUQEZFtMIAgh1RUXIp9J1Ox50QqEi9nwXA/J5WIENPSF/FRAWgT4gOpxLROgEKu/8hzBMI+dOvWEVOmvI0///wDJ08eh5ubKx599AkMG/YyAGDVqhX47bct+O67X4zH3L2tW7eOmDjxLfzxx1YkJSUiMLAxpkx5G5cuXcSXX65CXl4eOnfugmnT3jW7clVi4il8+ukCnDt3FhKJFB06dMSrr76BgIAAPPHEQ0hLS8UXX3yGo0ePYMmSlcjLy8Onny7Enj27UFJSgrCw1hgzZjzCwyOMbT58+F906hSPb7/9ClqtFj169MZ///sGXFxcAQD79/+Dzz9fjsuXL0GpdEZ8fFe8+urrFd7op6bewJNPPlxp+7/99mcEBjYut/3ixQvw9vap8JwAcPz4Ufj4+BqDBwBo164DRCIRTpw4ht69++HkyeMYN26CyXHt23fE33/vrLpjiYioXmAAQQ6lpFSH9dvOYv/pNBQW3x5FaBnkgS5RAYgN94eLQlbp8c5O+o98UT0NIARBgKak/NoFWp1Q62sayGXialXbWbJkASZMmIjJk6dh+/Y/sHLlUrRr1wExMe3NPsdnny3FlCnvIDg4GB9+OAOTJk1AeHhrzJ27EFevXsF7703HL7+0xRNPPFPlubRaLSZNmoBHHx2MadPeQ25uLubM+QizZr2PhQuX4rPP1mL48KHo0+d+vPDCMAiCgIkTx0MuV+DjjxfA1dUVv//+K0aPHo4VK75AaKh+qlBSUiIA4JNPPkV+fh7+978P8M47b2HevEXIzs7GtGkTMW7cBHTp0g0ZGen44IN3sXTpQkyZUv5bfX//Rvjpp98rfQ2enl4Vbr948QKUSiWmT5+EkyePw9PTC//5z8N44olnIBaLkZmZAX//RibHyGQyuLt7ID09HXl5uSgsLCy3j6+vHzIy0qvsWyIiqh8YQJBD
OXctGzsTrgMA/D2ViI8KQHxkI/h7OZt1vMJJAkCfSF3fCIKAWesTcOF6jk2ev2WQB6YOaW9xEDFw4CAMGPAgAOCFF17Cxo3rcPLkcYsCiAcffBjduvUAAAwY8CDmz5+N11+fjKZNgxES0hIbNqzFpUsXzTpXfn4+cnKy4evrh4CAQDRu3ATvvfcRVCr9aspeXvrF0JRKJdzdPXD48L84deokfv11O9zd9Qv7jRo1FidPHse3324yJieLRCJ88MH/4OvrBwB4/fXJePPN8bh69TKKi4uh0WjQqFEAAgICERAQiI8//qRszY/yJBIJfHx8ze4fg+Tki8jNzUWvXn0xbNhInDhxDEuXLoJarcaIEa+gqKgIcrm83HFyuRwaTTGKioqMv5d/vL6URyYioqowgCCHkqXW38BEt/TFG0+3hVZr2Vx0pZNhClM9zYFwwHL7zZo1N/nd1dUVJSUlFp0jKKip8f+VSiUAoEmTIOM2Jycns8/p7u6O5557AfPmfYyVK5ehQ4dYxMd3RZ8+91e4/7lzSRAEAY8/Pshku0ajQXFxsfH3pk2DjcEDALRpEw1APyrQu3c/9Os3AJMnT4CPjy9iYzuhS5fu6NGjV4XPmZaWhueff7LS17Bu3bcICAgot33u3EUoLi6Gq6t+2lSLFi2Rn5+HL79chZdeGgknJ6cKAwGNRgOlUgknJyfj73c/fmeVJiIiqt8YQJBDUeXpb8gaeTuXfdNtYQBhyIHQ1L8RCJFIhKlD2lc4hUkqFdvtFKaKvvG+V5JyRd/KS6XlL2VicfXXyRw9+lU8+eRT2Lt3Dw4f/hfz58/Gxo1rsXr1hnLt1el0cHFxwapV68udRya7PZ1OIjFto1arK2unflRsxoyZeOmll3HgwD4cOnQQH3zwNqKjY7Bw4bJy5/X19cUXX2ystP2+vhWPTshkMpM2AUBISEsUFhYiN1cNf/9G2LPnb5PHS0pKoFbnwNfXH+7uHlAqlbh1K9Nkn5s3M+Hn519pe4iIqH7hStTkULLz9N98entU79vO2yMQ9S+AAPRBhJNcYpOf2lhtWCaToaCgwGTb3WsYWNvVq5cxd+4seHl549FHn8CHH87GvHmLcflyMi5cOAfAdGXlkJCWyM/PR0lJCYKCmhp/Nmz4Env33r4Zv3btKvLy8oy/nzp1AgAQFhaO06dPYdGieQgObo6nnnoOc+YsxNSp7+DIkUNQqbLKtVEqlZo8190/FQVUgiDgqacewerVK022nzlzGj4+PvDw8ETbtu2RkZFu0sdHjx4BAERHt4VIJEKbNjHGbQYJCYfRtm07s/uYiIgcGwMIcijZufoRCB/36gUQhtWni+rrFKZ6JioqGmp1DjZuXIfU1BvYvPl7HDiwr1af08PDE9u3/4GPP56Jy5eTcfXqFfz22xa4ubkbp1splUqkpFxDVtYtdOoUj1atQvHuu1ORkHAYKSnXsHjxJ9i69Rc0bx5iPG9hYQE+/PAdXLp0AYcOHcT8+bPRt+/9CAgIhIuLC3744VssXboIKSnXcOnSBezYsQ1BQcHw8PC0yusSiUTo0aM3vvpqPXbs+BPXr6fgp59+wMaNazF8+CsAgMjIKLRp0xbvvvsWzpw5jYSEw5gz5yM88MB/jCMMzzwzBNu3b8OmTetx5cplLF26EOfPn8VTTz1nlXYSEZH94xQmciiGKUw+HspqHa+Q198k6vqoffuOGD58FDZtWo9Vq5ajc+cuGD58JL79dlOtPaeHhyfmzl2EFSuWYNSoF6HVahEZGY0FC5YaS64+8cQz+PTTBWWlYr/C/PlLsXTpQrzzzhQUFhaiefMQzJw5Bx06xBrP6+/fCK1ahWHMmJchlUpw//0D8cor4wAAzZvfh5kz5+CLLz7Djz9+C7FYjPbtYzFv3qIaTcW62yuvjIOrqytWrFiCzMwMBAY2xn//+yYefvgxAPog46OP5mDevI8xfvwrcHJyQq9e/fDqq7fLtsbFdcbUqe9gzZrPsXLl
MjRvfh9mz15QLpeFiIjqL5HgqCsi2RmtVoesrPyqd7yLVCqGl5cLVKr8Wp+jXh9MWLIXOXkazH+tJ3xcZRb32alLt/DJN8cR7O+KGS/F1VIra1dJiQa3bqXCxycQMln5/IGK1EUORH1jzT6raD2L+qiyPqvqM+vt7QKJxDqBEq/FdYd9Zjn2mWXYX5arSZ9Zci3mFCZyGFqdDup8fQ6ETzVzIBRlORAcgSAiIiKqHk5hIoehzi+BIABikQjurk5Q5xRUfdBd6nsSNd3bJ598jN9+23LPfT76aC5iYzvVUYuIiIgcDwMIchiqsgRqT1c5JOLqVfxRluVAFGm0EAShVioHkf0aNmxklcm+d67VYC3Dh4/C8OGjrH5eIiIiW2AAQQ4juyyB2tPNqdrnMIxAaHUCNKU6OMkkVmkbOQYvLy94eXnZuhlEREQOjTkQ5DAMIxBeNQggFHKJcbHmIk5jIiIiIrIYAwhyGIYRiJoEECKRqB4lUrOAGjkGFvsjIqpfGECQw8i2wggEADg73c6DcEQSiQSACMXFRbZuCpFZNBr9365EwlmzRET1Aa/m5DCMORCuNQsg9CMQxQ47AiEWS6BUuiAvLxulpSVQKJwhFkvumRCu04mg1fJbYEuwzyx3d58JggCNphh5eSoola5WXRSPiIhshwEEOQxVnn4NiJqOQBhLuRY5ZgABAO7u3pDJnJCXl42ioqoXzRKLxdDpuAiPJdhnlqusz5RKV7i7e9ugRUREVBsYQJDDsN4UprIAQuO4AYRIJIKzsyuUShfodDrodJVPx5JIRPDwcEZOTgG/UTcT+8xylfWZRCLlyAMRUT3DAIIcQnGJ1jjlqCZlXAF9JSYAKCx2zByIO4lEIkgkkrK8iIpJpWIoFAoUFmotXta+oWKfWY59RkTUcPBrIXIIhvwHuUxsHEGoLsPxLONKREREZDkGEOQQsnNvJ1DXdPXo+lPGlYiIiKjuMYAgh6AyrAFRwwpMwB1J1AwgiIiIiCzGAIIcQnauvgJTTfMfAEBpyIFw0HUgiIiIiGyJAQQ5hGyOQBARERHZBQYQ5BBuLyInr/G5mERNREREVH0MIMghqAxJ1FaYwsQkaiIiIqLqYwBBDuH2CETNAwjjCARzIIiIiIgsxgCC7J4gCFCVJVHXdBVqAFA46ZOoOQJBREREZDkGEGT38otKUarVr2xrjRwIQxJ1sUYLnU6o8fmIiIiIGhIGEGT3DIvIuSplkEklNT6fUn57JesiDUchiIiIiCzBAILsnjUrMAGATCqGVKL/6HMaExEREZFlGECQ3bNmBSYD57I8iKJiJlITERERWYIBBNk9a1ZgMmApVyIiIqLqYQBBdk+VV1aByYoBhNJYypUBBBEREZElbB5A6HQ6LFq0CN27d0dMTAxefvllXLt2rdL9z58/j5EjR6JTp06Ij4/H+PHjcePGDePjhYWF+OCDD9CtWze0bdsWQ4YMwbFjx0zOsWzZMoSFhZX7IfuUXQtTmJRylnIlIiIiqg6bBxBLly7Fxo0b8cEHH2DTpk3Q6XQYMWIENBpNuX1VKhWGDRsGhUKBdevW4bPPPkNWVhZGjBiB4mL9Teb06dOxd+9efPLJJ/j5558RGhqKYcOGIT093Xies2fP4pFHHsHevXtNfsg+qaycRA3cHoEoZA4EERERkUVsGkBoNBqsXr0a48ePR69evRAeHo758+cjLS0N27ZtK7f/9u3bUVBQgNmzZyM0NBRRUVGYM2cOLl68iISEBGi1WsjlcsyYMQNxcXFo1qwZXn/9dRQUFCAhIcF4nnPnziEiIgJ+fn4mP2SfDDkQ1lhEzsA4hYkjEEREREQWsWkAkZSUhPz8fMTHxxu3ubu7IyIiAocOHSq3f3x8PJYuXQqFQmHcJhbrX4JarYZEIsGsWbOM58vLy8PKlSvh4uKCmJgYAPqg5fLlywgJCanFV0bWotXpoM7Xj0ZZM4laySRqIiIiomqRVr1L7UlLSwMABAYGmmz3
9/c3PnanoKAgBAUFmWxbuXIlFAoFYmNjTbYvX74c8+fPh0gkwsyZM43PceHCBWi1Wvzxxx+YOXMmiouLERsbi4kTJ8Lf379Gr0cqtTwek5StR2D4L5lSqzUQBEAsEsHbXQGxWGSVPnNRygAAxSXaar1vjoSfMcuxzyxnT33Ga3HdYJ9Zjn1mGfaX5eqqz2waQBQWFgIA5HLTue1OTk7Iycmp8vh169Zh/fr1mD59Ory9vU0eGzhwIHr06IGtW7caH+/duzfOnTsHAFAqlVi4cCFu3bqFTz75BC+88AI2b95sMrphCbFYBC8vl2odCwDu7spqH1ufZeaWVWByd4KPj6vJYzXpMx9P/bFaATV63xwJP2OWY59ZztZ9xmtx3WOfWY59Zhn2l+Vqu89sGkAYbtY1Go3JjXtxcTGUyspfuCAIWLhwIZYtW4bRo0fj+eefL7dPs2bNAAARERE4c+YMvvjiC/Tu3RuPPvooevToYRJwtGrVCj169MDOnTvx4IMPVuu16HQC1OoCi4+TSMRwd1dCrS6EVqur1nPXZ1dvZAMAPFzkUKnyAVipz3T647Jzi4znra/4GbMc+8xyNekzd3el1b4t47W47rDPLMc+swz7y3J1dS22aQBhmFaUkZGB4OBg4/aMjIxKy6qWlJRg6tSp2LJlC6ZOnYoXX3zR+Fh+fj727NmDzp07w9PT07g9NDQUO3fuNP5+92iFv78/PD09K5w2ZYnS0up/uLVaXY2Or69u5RQB0AcQd/dPTfpMLtWXcS0sKm0w/c7PmOXYZ5azhz7jtbhusc8sxz6zDPvLcrXdZzadVBYeHg5XV1ccPHjQuE2tViMxMbFcToPBpEmT8Pvvv2PevHkmwQOgX1Pi9ddfx++//26y/cSJE2jZsiUAYP78+RgwYAAEQTA+npKSApVKZdyH7Icq1/oVmAAmURMRERFVl01HIORyOYYOHYq5c+fC29sbTZo0wZw5cxAQEID+/ftDq9UiKysLbm5uUCgU+OGHH7B161ZMmjQJcXFxyMzMNJ7Lzc0Nbm5ueOqpp7Bw4UIEBAQgODgYmzZtwvHjx7Fp0yYAwP33349Vq1ZhxowZePHFF3Hz5k189NFHaN++Pbp3726rrqBKZBvXgLBuAOFsLOPKdSCIiIiILGHTAAIAxo8fj9LSUkyfPh1FRUWIjY3FqlWrIJPJkJKSgr59+2LWrFkYPHgwtmzZAgCYPXs2Zs+ebXIewz5vvfUWPDw88N577+HmzZuIjIzEmjVrEBUVBQCIiorCZ599hoULF2Lw4MGQy+Xo27cvJk+eDJFIVOevn+4tu5ZGIBROXImaiIiIqDpsHkBIJBJMnDgREydOLPdYUFAQzp49a/x99erVVZ5PLpdjwoQJmDBhQqX7xMfHm6w9QfYrO8/6a0AAd4xAaBhAEBEREVmChXXJrhlyIDytPQIh1wcQpVoBJaWcxkRERERkLgYQZLeKS7TGKUZervIq9raMYQoTABQwD4KIiIjIbAwgyG4ZEqjlMrGxapK1iEUiKOT6IKKIeRD10rWMPBSXMDgkIiKyNgYQZLcMCdSerk61kuDOUq71179n0vHu6n/x9c4Ltm4KERFRvcMAguyWqmwEwsvKCdQGt0u5MoCob3YeSQEAnLx4y8YtISIiqn8YQJDdys4tq8Bk5QRqg9ulXDnNpT5JVxXgXEoOAOCWugg5+Robt4iIiKh+YQBBdiu7lkcglCzlWi/9czLN5PfkG2obtYSIiKh+YgBBduv2KtTWrcBkoJQzB6K+0ekE7DuVCgDwcNF/bi6l5tiySURERPUOAwiyW9m1tAaEgWEEopABRL1x5qoKWepiODtJ8WDnZgCASxyBICIisioGEGS3VHm3qzDVBqWToYwrcyDqi39O6EcfOkU0QliwJwAgOVUNnSDYsFVERET1CwMIskuCICA7r3aTqFnGtX4pKCrBkXOZAIBu0YFo4ucCuUyMwmIt0m4V2Lh1RERE9QcDCLJL+UWl
KCnVAbD+KtQGTKKuX/5NykBJqQ6NfV3QPMANErEYzRu5AeA0JiIiImtiAEF2yZBA7aKQQiaV1MpzMIm6fjFMX+rWJtC48GBIYw8AwKVUBhBERETWwgCC7JIhgdqrlqYvAUyirk9u3MzHxRtqiEUixEc2Mm4PaewOALh0g5WYiIiIrIUBBNml2k6gBphEXZ/8U1a6tU2INzzu+MwYAoiUjHxoSvg+ExERWQMDCLJLtV3CFWASdX2h1emw75R+8bhu0YEmj3m5OcHDRQ6dIOBKeq4tmkdERFTvMIAgu2SswFSrIxBMoq4PTidnISdPA1elDG1b+po8JhKJ7pjGxDwIIiIia2AAQXZJVYc5EEXFWq4T4MD2ntSPPnSOaASppPwljQEEERGRdTGAILuUbcyBqJ0SrgDgXJYDIQAo1nB+vCPKKyzBsfO3136oSEggAwgiIiJrYgBBdsmQRF2bIxBSiRgSsb7cJysxOaaDieko1QoI9ndFcNmaD3drHugOEYBb6iLk5GvqtoFERET1EAMIsjtanQ7q/NrPgRCJREykdnB7T+qrL3VtU/HoA6CfqtbY1wUAy7kSERFZAwMIsjvq/BIIAiAWieDuXHtTmICGUcpVU6LF6Uu3INSzPI+UjDxcScuFRCxC5zvWfqjIfcyDICIishoGEGR3DPkPHq5yiMumGNWW+j4CodMJmLfpGKZ8uhcHTqfbujlWZRh9iGnpC7cqAk0mUhMREVkPAwiyO8Y1IGpx+pKBUl6/S7n+dvAKzlxRAQCOnMu0cWusp1Srw/7T+upL95q+ZGBIpL6cpmbFLSIiohpiAEF2R1UHFZgM6vMIRHKqGpv3JBt/T0zOqjc3zycv3kJuQQncXeRo08K7yv2b+LlALhOjsFiLtFsFddBCIiKi+osBBNmd7DqowGRw51oQ9UmxRouVP5+GViegY7gflE5S5BWW4Go9WY3ZMH2pS2QAJOKqL2MSsRjNy6o0cRoTERFRzTCAILujqsspTGVJ1PVtBOKrHeeQriqEl5sTXvpPBKLLVmg+nZxl45bVnDpfgxMXbwEAurYJMPu4kMYeAIBLqQwgiIiIaoIBBNmd7LzaL+FqYBiBqE/rQBw5m4Hdx1MhAvDyoAi4KmWICfUDACReVtm2cVZw4HQatDoB9wW6oYmfq9nH3U6kZilXIiKimmAAQXbHkERdt1OY6kcAocotxprfkgAAAzs3Q3gzLwAwBhDnU7JRXOK407UEQTBOX+pmRvL0nQwBREpGvkP3ARERka0xgCC7k80k6mrRCQI+35KI/KJSNAtww6Pd7zM+1sTPFd7uTijVCjifkm27RtbQ1fQ8pGTmQyoRIy7i3ms/3M3LzQkernLoBAFX0upHLggREZEtMIAgu6Ip0SK/SH8zXzcjEGULyWkc/xvpP/69ijNXVJDLxBj5UASkktt/3iKRCJH36asVJSY77jSmvSf0ow/tQ33hopBZdKxIJDKWc2UiNRERUfUxgCC7Yhh9kEvFxtGB2mRYB8LRRyCupOXih78vAQCe6xeKQB+XcvtE3ecDADh92TETqUtKdTiQqF/7wdLpSwbGPAgmUhMREVUbAwiyK8YKTG5OEIlqdxVqoH4kURdrtFhRVrK1fagfukdXfHNtGIG4lpGHnHxNXTbRKo5fuIn8olJ4uTkhonnVaz9UxDACkcwRCCIiompjAEF2pS4rMAH1I4n6653nkZZVAE9XOV4cGF5p4OXuIkewv75q0RkHHIUwrv0QFQCxuHrBZfNAd4gA3FIXOWQQRUREZA8YQJBdUdVhBSbgznUgHDMHIuFcJv46dgMiACPKSrbeS0TZKISjTWPKzivGyUuGtR+qN30J0AeMjX3107tYzpWIiKh6GECQXanLCkzA7RGIUq0OJaW6OnlOa7mzZOuATsFmTeuJLNsn8bIKgiDUavusaf+pNAgC0LKJBwK8nWt0rvsaM5GaiIioJhhAkF0xBBBedTWFSX47UbtQ4zjTmHSCgFW/
JiKvsATBjVwxuEeIWce1CvKAVCKGKrcYaVkFtdxK6zBZ+6GS/A5LhDCAICIiqhEGEGRXsu9Ioq4LYrEITvKyUq4OlAfx56FrSLysglwqxqiHI01Ktt6LXCZBqyAPAMDpZMeYxnQpVY3UWwWQS8WIDfev8fkMidSX09TQOdAoDBERkb1gAEF2pa6TqAFAWRZAFDpIHsTV9Fx8//dFAMAzfVtVWLL1XozrQVx2jPUg/ilb+6FDmJ9VSvs28XOBXCZGYbEWqbccYxSGiIjInjCAILshCAJUeXU7AgE41mrUxSX6kq2lWgHtWvmiZ0xji89hyINIuqpCqda+8z40JVocPJMBoPprP9xNIhajeYBhGpP1E6lv3MzHrqPX7b5viYiIqqtaAURWVhbmzJmDxx57DN26dUNSUhKWLFmC7du3W7t91IAUFJcaE5m96iiJGnCsUq7f7LyA1FsF8KiiZOu9NG3kClelDEUaLZJreUE1QRCg01V/mlDC+UwUFpfCx12BsGZeVmtXba0HoRMEfPrjSaz74yy+2HrGoRLViYiIzGXxfIBr167h2WefRXFxMTp06ICkpCRotVokJydj6dKlWLp0KXr16lULTaX6zlDC1UUhhUwqqbPndZQRiGPnb2LX0esAgBH/iYCbc/WCLLFIhNbNvHAoKQOnk7PQKsjTiq28rVSrw8y1R5CSmQc/TyUCvJ31Pz7Oxv93c5bdMwgyTF/q2iYAYisuLFhbK1KfunTLOC1q/+l0+HgozU5wJyIichQWBxAff/wxfHx8sG7dOjg7OyMqKgoAMG/ePBQXF2P58uUMIKhajBWY6nD6EnA7B6JIY785ENl5xVi99QwAoH9sU2MeQ3VF3ueNQ0kZSLyswqPdrdHC8o6dv4kr6bkAgLSsggqrPjk7SdHojsAisOz//b2UyC0oMeZp1GTth4oYAoiUjHwUl2jhJLNOwLrt0DUAQLMAN1xJy8WWfZfh66FAj7aWTzUjIiKyVxYHEPv378dHH30Ed3d3aLWmN1xPP/00XnvtNWu1jRoYwwhEXSZQA/Y/AiEIAlZvPYO8whI09XfF4z1b1PicEc3104Eu3VCjoKgUzoqaJyff7a9j+tGSvu2DEBPqi7RbBcZAIu1WAbLURSgoLkVyqrrcVCoRAGeFFAKA8GBP+Hkqrdo2LzcneLjKkZOnwZW0XIQ29azxOa9l5CHxsgoiETD2sSjsOZ6KX/Zdxtrfz8LbzQlRIT41bzgREZEdqNZdg1Ra8WEajaZac7KJgDsqMNX1CISd50AcPX8Tpy5lQSoRY+TDkZBJa177wNdDiUZeSqSrCnH2qgrtQv2s0NLb0rMK9DfTAAbENYWvp9KYvG2gKdEiQ1WItKwCpJYFFekq/X8LikuRX6R/P3pUI1G8KiKRCCGB7jh6/iYu3VBbJYD4s2z0oUOYP3w9lHi0+324mVOE/afT8OnmU5g6pD2CG7nV+HmIiIhszeIAomPHjlixYgXi4+Ph5KS/0ROJRNDpdPjqq6/Qvn17qzeSGoZsG49AFNphAFFSqsM3Oy8AAB7o1BRNfC0r2XovEfd5I111HacvZ1k9gPj72A0AQJsWPvCtZPRALpMgyN8VQf6uJtsFQUBuQQnSsgpQqtWhtRWTp+8U0rgsgLBCHkROvgYHEtMA6KeYAfrr4rAHw5GdV4wzV1RY+N0JTHu+A7zdFTV+PiIiIluy+KvMN954AxcvXkT//v0xadIkiEQirFq1CoMHD8aRI0cwYcKE2mgnNQC3V6GuuwpMgH1PYdpxJAUZ2YXwcJXjwc7NrHruiGa1sx5ESanWuHJ0r5gmFh8vEong7iJHaFNPRDT3rrVRzZDG+gX1kq1QynVXQgpKtQJCGrujZRMP43apRIyxj0Whia8LVLnFWPDtcRQU2d/njIiIyBIWBxChoaH47rvv0KlTJxw8eBASiQT79u1DcHAwNm3ahNatW9dGO+kOWp0OV9JysTMhBV/vPI8sdZGtm2QVqjpe
hdrAXpOo1fka/LIvGQDwRM8WUMitm6fQupknRCJ9grM1P0OHz2Yir7AE3u5OiG5hv/P+mwe4QQTglroYOWXBa3WUlGqN1bEMow93clbI8N8no+HhIkdKZj6WbT7JNSKIiMihVeuO5L777sPs2bMhkZSt4FtYiNLSUri5WT6/V6fTYcmSJfj222+Rm5uL2NhYvPPOO2jatPw/xABw/vx5zJkzB8ePH4dYLEZsbCymTJmCxo0bG9syd+5c/PHHH8jNzUVUVBQmTpyImJgY4zlSUlLwwQcf4NChQ3B2dsYTTzyBV1991fh67I06X4OLN3Jw8boaF6/nIDlNDU3J7RuQmzlFGPtYGxu20DoMIxBMotb7cc8lFBZr0SzADfFRAVY/v7NChpBAd1y8ocbp5Cx0t1KloL/KbqZ7tG0Msdh+c6KUTlI09nPB9cx8XLqhrvY0rgOn05FbUAIfdyd0CKv4HL4eSrz2ZFv8b0MCTl9WYe3vZzHsweqt40FERGRrFo9AlJSU4N1338VTTz1l3Hb06FHEx8fj448/hk5n2TdrS5cuxcaNG/HBBx9g06ZN0Ol0GDFiBDQaTbl9VSoVhg0bBoVCgXXr1uGzzz5DVlYWRowYgeJi/c3n9OnTsXfvXnzyySf4+eefERoaimHDhiE9Pd3Y/uHDhwMANm3ahBkzZuCrr77Cp59+amlX1IpSrX50YceRFKz85TQmL9+H1xbvxeLvT2LrgSs4ey0bmhIdlE5S49zwo+duOvwohFanQ06+/j2v8zKudphEfTU9F7uP6/MInu3byqprINwpoiyx+fTlLKuc73pmHs6n5EAsEqF7tP2XLr0vsGbrQQiCYCzd2rdDU0jElV9SmwW4YfSjkRCJgL0n9RWaiIiIHJHFIxCLFy/Gzz//jFdffdW4LSIiAm+++SYWL14MLy8vjBw50qxzaTQarF69Gm+++aZx7Yj58+eje/fu2LZtGwYNGmSy//bt21FQUIDZs2dDodAnIs6ZMwe9evVCQkIC4uLiIJfLMWPGDMTFxQEAXn/9dWzcuBEJCQkYOHAg/vjjD9y4cQPffPMNPDw8EBoailu3bmH27Nl45ZVXIJfX7fx7TakWB06l4tjZdJy/loPLd40uGDT2dUGLxu5o0cQDLZp4INDHGWKRCLM3JiDpajZ2Hb1ulfKetqLOL4Eg6Bc5c6/mAmnVZW9J1IIgYNOO8xAEIDbc3yoVgioTeZ83ftl3GYmXVdAJQo0Dlb/KkqfbtfKt80CwOkIau2PviVRcquaK1ImXVbh+Mx9OMgl6tK16rYroFr54vn8Y1v5xFpv3JMPHXWH1NS6IiIhqm8UBxC+//ILJkyfjmWeeMW7z9PTEiy++CKlUirVr15odQCQlJSE/Px/x8fHGbe7u7oiIiMChQ4fKBRDx8fFYunSpMXgAAHHZN35qtRoSiQSzZs0yPpaXl4eVK1fCxcXFOIXp8OHDiIyMhIfH7UTHzp07Iy8vD2fOnEHbtm3N7wwrWPbjKRw5m2myzdlJipAm7mjR2AMtmrgjJNAdzgpZhcf37RCEpKvZ+PvYDTzctXmdruBsTYbpSx6u8jqf9qJ00vdZQbF95EAcO38TSVezIZWI8WSv2g0KQxq7w0kuQV5hCVIy8mpUZrRYo8W+U/pKRL3aWZ48bQshZSMQl9PU1QqgDKMP3aIDK/0bvVuvdk2QmVOI3w5cxZrfkuDt5oTWzWu2MCAREVFdsjiAUKlUleYnhISEIC0tzexzGfYNDDT9Bs7f37/C8wQFBSEoKMhk28qVK6FQKBAbG2uyffny5Zg/fz5EIhFmzpxpfI60tDQEBJjOJ/f39wcApKam1iiAkFajPn94My/kFJSgqZ+LvoJLkKdxdMEcHVv7w3uHE7LUxThyLhPdHGDaSEXUBbenL1XVjxKJ2OS/NeVWNuJRVFwKiURk03npJaU6fL1LX7b1wc7BCLBC2dZ79ZdUKkZ4sBeOX7iJM1dVCLmjgpCl/jmVisLi
Uvh7KdGmpU+tTbuypmaBbpDLxCgs1iIzuxBN/PQlZc35jF2/mY+Tl25BBOCBTsEW/f0/3bcVVOpiHEhMx6c/nsL0FzsiyM+16gPtmLX/LmuiOtdie2q/o2CfWY59Zhn2l+Xqqs8sDiBCQkLwxx9/oGvXruUe27lzJ5o1M7/UZGFhIQCUmzbk5OSEnJyqSyuuW7cO69evx/Tp0+HtbfoN3sCBA9GjRw9s3brV+Hjv3r1RVFQEd3f3cs8HwJhHUR1isQheXpbf7D0zoDWeGVCzylWDuoVg7dYz2Hn0Bgb1aOmQiZmasllb/t7OZveju7t1VidWuujffwGAwtnJ7G+Sa8MPuy4gQ1UILzcnDP1PpHF6lTVU1l9xkQE4fuEmzl7LwdAHqx+w7D6uL936n673wcfbcW6GWzX1wulLt5CWXYSo0EYmj93rM7Zxhz7Qi4sMQHgLyxOwJ74Qi7dX7ENichbmf3Mcc8f3qBdrRFjr77K6qnstNrB1+x0R+8xy7DPLsL8sV9t9ZvHdyQsvvIApU6YgOzsb/fr1g4+PD7KysrBr1y789ttvJlOIqmKYiqTRaEymJRUXF0OprPyFC4KAhQsXYtmyZRg9ejSef/75cvsYApmIiAicOXMGX3zxBXr37g2FQlEuQdsQODg7O5vd9rvpdALU6gKLj5NIxHB3V0KtLoS2mqUd48L88NUfZ3HhWjaOnE5Fixp8i2wr19P0c9BdnKRQqfLvua81+uxOQtnUFZ0gIDVdbbObOHW+Bpv+TAIAPNGrBYoKilFUUP2g1qCq/goJ0N/sn750C+mZasirMQ0uOVWN89eyIZWI0KGVb5XvoT1p1sgVpy/dwonzmejQyhdA1X2WW6DBzkNXAQB92zep9usd+1gUPlhzGGlZBXhnxT5Me6GD1cv11pWa/F26uyut9m2ZLa/FDQ37zHLsM8uwvyxXV9dii/+levTRR5Gfn4+lS5di27Ztxu1eXl54++238eijj5p9LsO0ooyMDAQHBxu3Z2RkICwsrMJjSkpKMHXqVGzZsgVTp07Fiy++aHwsPz8fe/bsQefOneHp6WncHhoaip07dwIAAgICcO7cOZNzZmRkAAAaNTL99tFSpaXV/3BrtbpqH+/sJEVca3/8cyoNf/x7FSMfiqx2O2wlS12WA+EiN7sfatJnd1M6SZBfVIrcgpI6T+I2+G7XBX3Z1kZu6BTRyGqvzaCy/vL3VMLTVY7sPA3OXFYhshrz8XccTgEAdAzzh7OT1Optr03Ny/I+LqbklGt3ZX22/XAKNKU6NGvkhhaN3av9epVyKV57qi1mrj2MK2m5WPL9Sbz6eJt7VnOyd9b8u6wuW12LGyr2meXYZ5Zhf1mutvusWv9KDRkyBHv37sXWrVuxceNGbNmyBf/88w+ee+45i84THh4OV1dXHDx40LhNrVYjMTGxXE6DwaRJk/D7779j3rx5JsEDoF9T4vXXX8fvv/9usv3EiRNo2bIlACA2NhaJiYnIy8szPn7gwAG4uLggPDzcovbbkz4d9Lkhh85kGMuhOhKVjdaAMLB1KddrGXn421C2tV/tlW2tiEgkMgYNicmWl3MtKCrFwUR9mWRHSZ6+U0hj/ZTGlMx8FJdUnUhfqtVhZ4I+YOof27TGUwb9PZX47xNtIZeKceLiLWzYdg6CINTonERERLWp2l9ziUQihISEoH379mjZsqWxGpIl5HI5hg4dirlz52LHjh1ISkrChAkTEBAQgP79+0Or1SIzMxNFRfo1Dn744Qds3boVEyZMQFxcHDIzM40/RUVFcHNzw1NPPYWFCxfir7/+wqVLl/DRRx/h+PHjGD16NACgX79+8PPzw2uvvYakpCRs374dn3zyCV566aU6L+FqTfcFuqNFY3dodQL+Pnbd1s2xmKEKk61Kf9qylKsgCPhq+7k6KdtaGcN6EImXVRYfeyAxDcUlWjT2dUGrIMebPuftroCnqxw6QcCVtNwq9//3TDpy8jTwcJUjtrW/
VdoQ0tgdox6OhAj6UriGYJKIiMgeWXzXn5WVhTfeeAMdOnRAREQEWrdubfITERFh0fnGjx+PJ554AtOnT8ezzz4LiUSCVatWQSaTITU1Fd26dcPWrVsBAFu2bAEAzJ49G926dTP5Mezz1ltv4amnnsJ7772HRx55BCdOnMCaNWsQFRUFQJ8w/fnnn0On0xn3e+655zBmzBhLu8Lu9C0bhfjr6HWUOthcwexcwwiEbYI4W65GXZdlWysT0Vy/KOHV9FzkFpg/giUIgnHl6V4xjR0ygR+4Y0G5KtaDMFk4rn0QpFasctEu1A8Pd7sPgH4kkYiIyF5ZnAPx/vvvY9euXfjPf/6DgICAao083EkikWDixImYOHFiuceCgoJw9uxZ4++rV6+u8nxyuRwTJkzAhAkTKt2nWbNmZp3L0XQM98emnReQnadBwrlMxLWuWU5HXdGUaJFfpL9xt9kIhFyfOFykqdu1IEpKdfh6p76az4C4pvD1tE2lCQ9XJwT5uSAlMx9nrqjM/uxcvK5GSmY+5FIxukQFVH2AnQpp7I6j529WuSL1uWvZuJqeB7lUXCvTtaJb+OCnvcm4lpEHQRAcNiAjIqL6zeIAYvfu3Xjrrbfw9NNP10Z7qAakEjF6xTTGz/9cxo4jKQ4TQBimL8mlYquWLbWEUlE2AlFUtyMQO46kICO7EB4ucjzY2fwSyLUhork3UjLzcTo5y+zPzq6y0Ye4iEY2LX9bUyGN9VOvkm/cu3y0YfShS1QAXJXWf71NfF0gFomQV1iC7DyNQ6zmTUREDY/FwwcymazSheTI9nrGNIFELML5lBxcTa96Prc9yM7TT5nxdHOy2TeuyrLSmUWaugsg1Pka/LIvGQAwuGeIzYIng8j7DHkQWWYl8eYVluBQkn6qTW8HTJ6+U/MAN4gA3FIXIyev4tK56aoCHDt/EwBwf2ztXAPlMgkCfPTlpK9lOMbfLxERNTwWBxD333+/MReB7I+XmxM6hOkXtdpxJMXGrTGPKte2FZgA2+RAbN5zyVi2tWubwKoPqGWhQZ6QSkS4pS5Ghqqwyv3/OZmKUq2+lGnzALc6aGHtUTpJ0dhPv/hYZXkQ2w+lQIB+mlGgT81XCK9MU3/9uhzXMvKq2JOIiMg2LP7KMyIiAgsWLMC1a9fQtm1bkwXgAH11prFjx1qtgWS5vh2C8O+ZDBxITMeTvVvWylQLa7J1BSZAvw4EABQV100OhC3LtlbGSS5ByyYeSLqajdOXs9DIu/KFFU2Sp9s5bvL0nUIC3XE9Mx+XUtWIjTCdwlVQVIK9J/UrbdfW6INBU39XHExMZwBBRER2q1pJ1ABw6NAhHDp0qNzjDCBsr2UTDwT7u+JqRh72nLiBgZ1sO7e+KiobV2AC6raMqyAI2LTjPARBn/hui7KtlYlo7q0PIJKz0Kd9UKX7JV1RIV1VCIVcgk4RjpFrU5WQxu7YcyK1whGIv4/fQHGJFkF+Loho5lWr7eAIBBER2TuLA4ikpKTaaAdZkUgkQt8OQfjityTsSriOAbHBEIvt9xvibBsvIgfU7RSmY+dv4swVFaQSMZ6yUdnWykTe540fdl9C0lUVtDpdpSsi7zqmHz2JjwqAQm7b3A1rMSZSp6qh093OAdHqdMbpgPd3rPnCcVUJLgsg0rIKUFyihZNMUqvPR0REZCnrFTEvc+cKz2Q7nSIawUUhxc2cIhy/eNPWzbknwxoQNp3CVEdJ1PZStrUyzRq5wUUhRWGxFsmpFSfx5uQV4+i5TABArxjHTp6+U2NfZ8hlYhRptLhxK9+4/cjZTGSpi+HuLEPnyNofbfFwdYK7swyCAFzPzK/6ACIiojpm8VeHGo0GX375Jf79919oNBpjtRZBEFBQUIALFy7g+PHjVm8oWUYuk6BH28b47eBV7DiSgnat/GzdpEoZqzDZdARC/y1vQS3nQNhT2daKiMUitG7mhcNnM5GYnIWWTcqvLL3nRCq0OgEtm3gYp9vUBxKxGM0D
3HHuWjYuXs9Bm9BGEAQBf/yrL93aq10TyKR1MxrQ1N8Vpy+rcC0jFyGN3evkOYmIiMxl8QjE7NmzMW/ePKSnp+PixYu4fv06CgsLceLECZw5cwajRo2qjXZSNfRu1wQiEZB4WYUbN+3zm0xBEKAyTGGyaRJ12QhELU5hUhfYV9nWykQ0v13O9W46nYC/y6Yv9WrXuE7bVRcMN+uGPIgLKTlITlVDKhGj9z1yQqytqb++qhXzIIiIyB5ZHEBs27YNw4YNw88//4yhQ4ciKioK3377LbZt24YmTZpAp9PVRjupGnw9lYhp6QsA2JlgnyVdC4pLUVKq/8x4utTvJOrNu/VlW4MbudpF2dbKRJStB3Hxhrpcf5xKvoVb6iK4KKToGOZvi+bVqpBAfQBx8bp+Qbnf/70KAOgc2Qgedfj5bNqIidRERGS/LA4gsrKy0KNHDwBAaGgoTp48CQBo1KgRRo4cia1bt1q3hVQjfTvovzX951RanVQYspQh/8FFIYXchsmihgBCU6pDqdb6QXB6VoGxbOtz/ULtomxrZfw9lfDzVECrE3D2WrbJY38d1b+Grm0Cbfp+1RbDCERKRj6upqlxuGyhvP4d63bxzDsrMenMWNSPiIioLlkcQLi5uUGj0c9Zb9asGVJTU42J082bN0dqaqp1W0g10rqZFwJ9nFGs0eKfk/b33tjD9CUAUMhv3wwXaayfB3HphhqCALQM8rCrsq2ViTRMY0q+PY3p1h0J+T1j6t/0JQDwdlfA01UOnSBg0TfHIAhARHMvBNVxrkeAtzOkEhGKNFrczCmq0+cmIiKqisUBRMeOHbFu3ToUFhaiWbNmUCqV2L59OwDg6NGjcHWtP0mV9YGhpCsA7Ei4bnffZmbn6oNRLxsmUAOAVCKGXKb/c6iNUq6GQMnfzqouVcaQB3H6jjyI3cdvQBCA8GDPWl2J2dYM5VzPXlEBAPrX8sJxFZFKxGjsq+/ja+mcxkRERPbF4gBi7NixOHbsGEaOHAmpVIrnnnsOb7/9NgYPHoyFCxdiwIABtdFOqoEuUQFQOkmQnlVQYWKsLansYA0IA2Mp11oIILJz7ed1miO8mRdEAFJvFUCVW4xSrQ67TxiSp+tP6daK3Fn1KNDHGVEhPjZpx+1pTBWX0yUiIrIVi8vAhIeH47fffsO5c+cAAG+88QZcXV2RkJCAPn36sAqTHVLIpegaFYjtR1Kw43AKou6zzQ1RRbLtZAoToM+DyMnX1EquiCFQsuVaF5ZwVcrQPNANyam5SLycBYVcgpw8DdydZWgfar8lga3BkEgNAAPigm2WrxLs74Z/kMZEaiIisjsWBxCHDh1CREQEunbtCkA/ReaVV14BAKjVamzbtg3/+c9/rNtKqrE+HYKw/UgKTly8hYzsQruZSmNcRM7VdhWYDAxrQRTWwloQt1fbtv3rNFdEc28kp+bi9OUs5Obrp5p1b9sYUonV15+0K/cFusPDVQ4nmQRdo21XLevORGoiIiJ7YvGdwAsvvICLFy9W+FhiYiKmTp1a40aR9QV4OyPqPm8IAHbZUUlXexuBAGqnlKtxCpMdvE5zGRKpj1+4idOXVRAB6NG2fiZP38lJLsH/RsVj4Ru94WTDSlOGUq43c4pQUGR/FdSIiKjhMmsEYvLkycbqSoIgYMaMGRUmS1++fBm+vr7WbSFZTd8OQTiVnIU9x1PxaLcQOMltX4ZTZUe5AYYAwtpJ1DpBMK62betkcUu0aOIBuUxsHJGJCvGBn52MXNU2F6UMrkoZVEUa27VBIYO3uxOy1MVIycxziOpdRETUMJg1AjFgwAAIggDhjgo+ht8NP2KxGDExMZg1a1atNZZqpk0LH/h5KlBQXIoDiWm2bg50OgE5ZVNj7CE3wJhErbFuAJFXUAKtToAIgLsNF8uzlEwqNrlprY8rT9u7YK5ITUREdsisEYg+ffqgT58+AIDnn38eM2bMQIsWLWq1YWR9YpEIfdoH4eudF7DjSAp6tG0M
kQ0XNMvJ10AQ9O1yd7b9jXVtjUAYpmm5ucgdLn8gsrk3Tl3KgpebE6Jb2E/yfUMR5O+KYxdu4mo6KzEREZH9sPhuZt26deWCh1OnTmHbtm1Qq9VWaxjVjm7RgZDLxEjJzMe5u1YZrmuGG2sPVznEYtuvzGxIoi6ychK1ypgobvtRFkv1aNsYXaMCMGxgOCRixwp+6oNgJlITEZEdsviOICMjA88//zyWLl0KAFi/fj2efPJJjB8/Hv3798f58+et3kiyHheFDPGRAQCAHUdsm0x9e20E248+ALWXRO2IFZgMlE5SDB8UYbO1EBo6QyWm6zfzodXpbNwaIiIiPYsDiDlz5iA5ORlt2rSBTqfD8uXL0aVLF2zevBktW7bEvHnzaqOdZEV92+tXpk44dxNZ6iKbtSPbjhaRA2ovgDCOQNhBngc5Fj8vJZxkEpSU6pCeVWjr5hAREQGoRgCxd+9eTJ48Gd27d0dCQgJu3ryJF154AeHh4RgxYgQOHz5cG+0kKwryd0VYU0/oBAF/Hbtus3ao7KiEK1AXIxD28TrJcYhFIgT5uwAArnJFaiIishMWBxAFBQUICNBPgdm9ezfkcjk6d+4MAJDL5SaVmsh+9e2gH4X452Sazd6z7Fx9BSZ7ubE25EAUWDkHwlDC1V4CJXIsTVmJiYiI7IzFAUTz5s1x+PBhlJSU4I8//kBcXBycnPQ3Rj///DOaN29u7TZSLWjTwgcSsQiq3GJkZttmaoThm3l7SS6urTKunMJENcEVqYmIyN5YHEC8/PLLWLJkCeLj43Ht2jUMGzYMAPDEE0/g559/xvDhw63eSLI+J5kEIY3dAQBnr2bbpA23pzDZR3IxpzCRPWIAQURE9sasdSDuNGjQIAQGBuLIkSOIi4tDTEwMACA2Nhbjx49Hjx49rN1GqiVhwZ44n5KDpKvZ6N627hcJy7az8qa3AwgtBEGwyhoZJaU65BaUAOAIBFVPkJ8LRABy8jRQ52scajFCIiKqnywOIACgQ4cO6NChg8m2yZMnW6VBVHfCmnphC67g3DVVnT+3pkSL/CL9N/32khtgyIHQCQI0JTo4ySU1PmdOvj5IkkrEcFFU68+NGjiFXAp/LyXSVYW4lpGHyPu8bd0kIiJq4My6o5k6dSrGjBmDpk2bYurUqffcVyQS4aOPPrJK46h2tWziAYlYhFvqYtzMLoSvp7LOnjs7X59YLJeK4exkHzfWTjIJRCJAEPSrUVsjgLidKC636arf5Nia+rsygCAiIrth1p3bwYMH8X//93/G/78X3iQ5Die5BM0D3XDxuhpJV7PRrS4DiNzbeQH28pkRiURQyqUoKC4tS6Su+chItp2VqiXH1NTfFYfPZuIaS7kSEZEdMCuA2LlzZ4X/T44vrKkXLl5X4+w1FbpFB9bZ89rrjbXSSR9AFFgpkVplZ3ke5JhYypWIiOyJxVWYqH4JD/YEUPeVmFTGEQj7Sgg15EEUWWktCFZgImsIbqSvxJR6qwAlpTobt4aIiBo6s0YgXnjhBYtOunbt2mo1hupeiyYeEItEuJlThFs5RfDxUNTJ8xrXgLDDEQjAeqVcVXb6OsmxeLk5wUUhRX5RKW7czEezADdbN4mIiBows0YgBEEw+UlISMDRo0chCAJ8fX0hlUpx6tQpnDx5Er6+vrXdZrIipZPUeDNytg6rMaly7fObeWsHENl2OtJCjkUkEnE9CCIishtmjUCsW7fO+P9r1qxBVlYWVq1ahYCAAOP2rKwsjBw5EkFBQdZvJdWqsGBPJKeqcfZqNrpE1U0eRHaevjqRvX0zb/0RCPt8neR4gvxdkXQ1mwEEERHZnMU5EJ9//jn++9//mgQPAODt7Y1XXnkFX3/9tdUaR3XDmAdxLbvOnjPbzkcgrJFELQiC3b5Ocjy3RyBYiYmIiGzL4gCiqKgIgiBU+Fh+fn6NG0R1r1WQJ0QiIENVaJxaVJsEQbDfKkxlaz8UaWqeRF2k0aK4RH8eBhBUU8F3
VGKq7BpMRERUFywOIDp37oxPPvkEly5dMtl++vRpLFiwAD179rRa46huKJ2kCG5UlgdxtfbzIAqKS6EpqyTj6WJfuQHWHIEwBGNKJ6lVFqWjhq2xrwskYhHyi0rrJNAnIiKqjMVLAE+bNg1DhgzBoEGD0LRpU3h5eeHWrVtISUlBq1at8NZbb9VGO6mWhQd74kpaLs5ey0bnyICqD6gBw7QeF4UUcpl93VgbAogiKwQQ9lppihyTTCpGgI8zrmfm42pGHrzd66ZiGhER0d0sHoEIDAzEr7/+imnTpiEyMhIuLi6Ijo7GBx98gO+//x6enp610EyqbWFNvQAASXWwHoQhgdrepi8Bt9eBsEYS9e1F5OxrlIUcFysxERGRPbB4BAIAlEolhgwZgiFDhli7PWQjoU09IAKQnlWA7LziWp2zb68lXAFAKS+rwmSFHAguIkfWFuzvhgOn03EtnYnURERkO1yJmgAAzgoZmpatdnuulqsxGaf22OGNtTXLuGbn2u9ICzkmjkAQEZE9YABBRoZpTGdreRrTjZv6al32mBtgzSRqjkCQtRkCiAxVIYqtMEpGRERUHQwgyCisbD2IpFqsxFRcosXRCzcBAG1CfGrtearLkANRVFzzmzMVk6jJytxd5PBwkUMAkJLJUQgiIrINBhBkFNrUEwCQeqsA6nxNrTzHsfM3UazRwtdDgRZN3GvlOWrCMAJRXKKFVqer0bk4AkG1gdOYiIjI1iwOILKysip9rKSkBDdu3KhRg8h2XJUyBPnVbh7E/tNpAIDOkQEQiUS18hw1YQgggJotJqcTBOSUVZviCARZkyFX6SoDCCIishGzAghBELBs2TJ06NABXbt2RdeuXfHll1+W2y8xMRF9+/a1eiOp7tTmNKbcAg1OJ+sD0PjIRlY/vzVIJWLIpPo/i8Ki6udB5OZroNUJEIkAdxeZtZpHdMcIBCsxERGRbZgVQHz11VdYvHgxHnroIUydOhUhISGYNWsW3njjDehqOM1Dp9Nh0aJF6N69O2JiYvDyyy/j2rVrle5//vx5jBw5Ep06dUJ8fDzGjx9vMupRVFSEefPmoU+fPmjXrh0GDx6MHTt2mJxj+vTpCAsLM/np06dPjV5HfRFWNo3pbC2MQBxKyoBWJ6BZIzcE+rhY/fzWYo1EasNaF+4uckjEnClI1tPUX79qfEpGPnSCYOPWEBFRQ2R2APHyyy9jxowZeOGFF7Bu3TpMnToVW7duxZQpU2rUgKVLl2Ljxo344IMPsGnTJuh0OowYMQIaTfk5+CqVCsOGDYNCocC6devw2WefISsrCyNGjEBxsX6++YcffohffvkF7777LjZv3ox+/fph3LhxOHjwoPE8Z8+exSuvvIK9e/caf7777rsavY76IrRsBOJ6Zj5yC6ybB3HgdDoA+x19MFDKyxKpazCFyZ7XuiDHFuCthFQiRnGJFpnZhbZuDhERNUBmBRApKSmIj4832fZ///d/mDZtGn7++WfMmTOnWk+u0WiwevVqjB8/Hr169UJ4eDjmz5+PtLQ0bNu2rdz+27dvR0FBAWbPno3Q0FBERUVhzpw5uHjxIhISElBYWIjNmzfj9ddfR8+ePdGsWTOMGTMGcXFx+P777wHop2NduHABUVFR8PPzM/54e3tX6zXUN+7OcjTx1Y8OWDMPIiO7EBeu50AkAuIi7DyAsMoIhP2udUGOTSIWo4mf/m/0WjrzIIiIqO6ZFUD4+voiOTm53PahQ4fixRdfxOrVq7Fu3TqLnzwpKQn5+fkmwYm7uzsiIiJw6NChcvvHx8dj6dKlUCgUt19A2fQQtVoNkUiE5cuXo0ePHibHicViqNVqAMDVq1dRUFCAkJAQi9vbUBhGIay5HsTBsuTp1s287P5beUMAUVSDAMI4AsEEaqoFwf5MpCYiItuRVr0L0K9fPyxatAg+Pj7o3Lkz3N1vl9+cPHkyrl+/jlmzZqF3794WPXlamv6mMjAw0GS7v7+/8bE7BQUF
ISgoyGTbypUroVAoEBsbC4VCgW7dupk8fuLECRw4cADTp08HAJw7dw4AsG7dOuzevRtisRg9evTAhAkT4ObmZlH77yaVWj7XXSIRm/zXHkQ098auhOs4l5Jdrdd0N0EQcCBRP32pa5vAGp+ztvvMWVFWyrVUV+22qsumf/m4O1mlD2vCHj9j9s7e+6xZoBv2nEjF9cw8m3++DOypz+rLtdjesc8sxz6zDPvLcnXVZ2YFEGPHjsX58+cxfvx4PP3003jvvfeMj4lEInzyySeYMmUKfv31V4tKcxYW6ufvyuVyk+1OTk7Iycmp8vh169Zh/fr1mD59eoVTkC5duoSxY8ciOjoaTz31FAB9ACEWi+Hv74/ly5fj6tWrmD17Ns6fP48vv/zSOKJhKbFYBC+v6icGu7srq32stXWKboxPfziJaxl5kDrJ4OYsr/qge7iQko3UWwWQS8Xo17k5nBXWqUpUW33m4VY2wiUWV/s9zS2r4NSkkXuNPhfWZE+fMUdhr30W2cIPwDmk3My3m8+Xga37rD5dix0F+8xy7DPLsL8sV9t9ZlYA4erqis8//xyJiYkVPi6TyTBv3jwMHDgQf/75p9lPbpiKpNFoTKYlFRcXQ6ms/IULgoCFCxdi2bJlGD16NJ5//vly+yQkJGDMmDEICAjA8uXLIZPpb1pHjx6N5557Dl5eXgCA0NBQ+Pn54amnnsLJkyfRtm1bs9t/J51OgFpdYPFxEokY7u5KqNWF0GprVtHKmgJ9nJF6qwAHT1xHhzD/Gp3r93366W8xoX4oLtSguLBmydm13WfSshg4S1UAlSq/WufIzNJ/FuQSUbXPYS32+hmzZ/beZ57O+kt3pqoQKTey4aK0fangmvSZu7vSat+W1bdrsT1jn1mOfWYZ9pfl6upabFYA8fPPP6Nbt26IiIi45379+vVDv379zHpi4PbUpYyMDAQHBxu3Z2RkICwsrMJjSkpKMHXqVGzZsgVTp07Fiy++WG6fbdu24c0330Tbtm2xdOlSk6lJYrHYGDwYtGrVCoB+SlV1AwgAKC2t/odbq9XV6HhrC2vqidRbBThzWYW2LXyrfR6dTsCBsvyHTq39rfoaa6vPnGT6Kkx5RSXVPr8hB8LdWWY376u9fcYcgb32mZNUAl8PBW7mFCH5hhrhzbyqPqiO2EOf1adrsSNgn1mOfWYZ9pflarvPzAozJk2ahK5du+KRRx7Bxx9/jH379lVYZtVS4eHhcHV1NSmxqlarkZiYiNjY2Erb8vvvv2PevHkVBg87d+7EhAkT0KtXL6xatapcXsOkSZPKHXfy5EkAQMuWLWv2guqRsGD9DUlNE6nPXFUhJ08DF4UUbUJ8rNCy2lfTJOqSUh3yCksAsIwr1Z7bC8oxkZqIiOqWWSMQ33//PQ4dOoTDhw9j8+bN+OKLL+Dk5IT27dsbV6Zu3bq1xU8ul8sxdOhQzJ07F97e3mjSpAnmzJmDgIAA9O/fH1qtFllZWXBzc4NCocAPP/yArVu3YtKkSYiLi0NmZqbxXG5ubiguLsbkyZMRGRmJadOmmeRRyGQyeHp6YsCAARgzZgyWLFmChx9+GMnJyXj//fcxaNAgtGjRwuLXUF8ZVqS+mp6LgqKSauctHDilH32Ibd0IUgdJglI66UcgCourtw6EoYSrTCqGi8KsPzEiizX1d8XR8zcZQBARUZ0z6+4mMjISkZGRxm/uL168iH///RdHjhzBhg0bjAFAfHw8unXrhscee8zsBowfPx6lpaWYPn06ioqKEBsbi1WrVkEmkyElJQV9+/bFrFmzMHjwYGzZsgUAMHv2bMyePdvkPLNmzYJMJoNarcbx48fLlXKNi4vDunXr0LdvXyxYsAArV67EZ599Bjc3Nzz00EN47bXXzG5zQ+Dp6oRGXkqkqwpxLiUHMS0tn8akKdHiyDl9kNfZztd+uJNhBKKwmiMQhgDC01VuUVEBIktwBIKIiGylWl+PtmjRAi1atMCzzz4LADh4
8CA2btyIP/74A1u3brUogJBIJJg4cSImTpxY7rGgoCCcPXvW+Pvq1aurPN9DDz1U5T4DBw7EwIEDzW5jQxUW7KUPIK5mVyuAOHbhJoo0Wvi4K9AyyKMWWlg7ahpAGPIfuIgc1SZDAHH9Zh5KtTqHGeEjIiLHV60AIisrC3v27MH+/ftx8OBBpKWlwdnZGd27dy+3DgM5rrBgT+w+fgNJV1XVOv7Aaf3aD50jG0HsQN/EK+VlAYSmmiMQXESO6oCvpxIKuQRFGi3SsgoQ5Odq6yYREVEDYVYAodVqcfToUezZswd79uxBUlISAP3UpkceeQTdunVDTEwMpFLO965Pwpp6AgCupOeisLjU+M28OfIKS3Dy0i0AQOfIgNpoXq2peQ6EvsAAE6ipNolFIgT5u+JCSg6uZeQxgCAiojpj1h1hp06dkJ+fj8DAQMTHx+Pll19Gly5d4OHhONNSyHLe7gr4eSqQmV2E8yk5iG5hfhWlQ0kZ0OoEBPu7oomvfS10VZU7pzAJgmBxHoPKmAPBAIJqV9M7Aoj4SFu3hoiIGgqzJs3m5eXBw8MDPXv2RK9evdC9e3cGDw3E7XKulk1jMqz94GijD8DtAEKrE1BSjRrKhilMXpzCRLWMidRERGQLZo1AfPfdd9izZw/27t2Lb7/9FgAQHR2Nbt26oVu3boiOjq7VRpLthDX1xN4TqTh7LdvsY25mF+J8Sg5EADo5UPUlAye5BCIAAvSjEPKyheXMpbqjChNRbQr2169zcy0918YtISKihsSsACIqKgpRUVEYPXo08vLysG/fPuzduxffffcdFi1aBE9PT3Tp0gXdunVD165d0aiR4900UsUM60FcTs1FkaYUCnnVH5kDifrk6fBmXg75LbxYJILCSYLCYi0KNVpYMtYmCIKxjKsjvnZyLE38XCASAeqCEuTkFcOD0+aIiKgOWJz17Orqiv79+6N///4A9GtCHDhwAAcPHsSMGTNQWlqKxMREqzeUbMPXQwlfDwVu5hThwvUcRN137zwIQRCMAYQjrf1wN6WTVB9AWFjKtbC4FJoS/bQn5kBQbXOSSdDIyxlpWQW4lpHHAIKIiOpEtQuHZ2dnY9euXdi8eTN+//137N27FzqdjtOZ6iFDNaazV7Or3PdaRh5u3MyHVCJGhzD/2m1YLTKUci2wMIBQlVVgclFILZ76RFQdzIMgIqK6ZvYIxOXLl5GQkGD8SU5OhiAIaNWqFeLj4zF8+HDExsbCxcWxKu5Q1UKDPfHPqTSzAoj9ZcnTMS194Kxw3LK+hkTqIgsDCOMaEPwmmOpIU39XHErKwFUGEEREVEfMusPr3LkzcnJyIAgCGjdujPj4eIwZMwbx8fHw8TG/tCc5pvCySkzJqWoUa7Rwklf8zbpOJ+CgYfqSA1ZfupMhgLB0BMKQ/8BF5KiuBDfiCAQREdUts9eB6NKlC+Lj4xEcHFzbbSI74+uhgLe7E7LUxbhwIweRzb0r3O/sVRWy8zRwdpKiTYhjB5aGxeSKLFxMTmUo4coRCKojTcsqMaXdKkBJqRYyKafOERFR7TIrgFi4cGFtt4PsmEgkQlhTT+w/nY6zV7MrDSD2l40+xLb2h0xa7fQau3DnYnKWuD0CwRKuVDc8XeVwVcqQV1iC6zfz0TzA3dZNIiKies6x7/KozhgWlDtXyYJyJaVaHDmbAcCxqy8ZGJKoCzUWJlFzBILqmEgkup1Inc5pTEREVPsYQJBZDOtBXEpVQ1NSflrP8Qu3UFishbe7E1qVVW1yZIYpTNUegWAAQXXIEEAwkZqIiOoCAwgyi7+nEp6ucpRqBVy8oS73uKH6UqeIRhCLRHXdPKu7nURtWQ5EdlkZVyZRU11iKVciIqpLDCDILCKRyDiN6exd05jyCktw4uItAEC8g1dfMqhOGVedTkCOIYDgCATVoTsDCEEQLDpWJwi4eD0H3/51Aeu3nUVxBSOMREREd3LcQv1U58KCPXEwMb3cehCHz2ZAqxMQ
5OeKID9X2zTOyqqTRK0u0EAnCBCJAA8XJlFT3Wns6wKJWITC4lLcUhfB10N5z/1LtTokXVEh4fxNHD2faQx8AaCRlzPuj21a200mIiIHxgCCzGZYkfriDbVJucgDp/XVl+IjHT952kBZttZFocb8b2MNCdQeLnKIxY4/jYsch1QiRqCPC1Iy83AtPa/CAKKwuBQnL93C0fM3ceLiTRTeMT1PIZcg0McFyalq7Dx6HX07BtWLqYhERFQ7GECQ2QK8neHuIoc6X4NLN9QIC/bCrZwinLuWDRH0+Q/1hVJh+QiEIYHai/kPZAPBjVz1AURGHtqF+gEAcvI1OHY+E0fP30Ti5SyUam9Pb/JwkaNdK1+0C/VDeLAXSrU6vPHpP0jPKsCZK6pKyzUTERExgCCzGdaDOJSUgbPXshEW7IUDifrk6bBgT3i7K2zcQusxlnG1JIDIZQUmsh1DHkTSVRXkByVIOJ+Jiyk5uDMjopGXEu1D/dAu1A8hjd1NRhlkUjG6RgViR0IKdh5JYQBBRESVYgBBFgkPLgsgrmYDXYEDZYvHda4nydMGxiRqjRY6nWDWlCQVKzCRDd0OILKRdEee0n2BbmjXSh80NPZxhugeU5N6tW+CHQkpOHbhJrLURfXqSwEiIrIeBhBkkdCySkwXr+cgOVWN65n5kEpE6BjmZ+OWWZchgACAIk0pnBWyKo/hCATZ0n2B7nB3liGvsBRhwZ76kYZWvhYFAU18XRAe7Imkq9n469h1DO7RohZbTEREjooBBFmksY8z3JxlyC0owVc7zgMA2rbwNesG25HIpGJIJSKUagUUFmvNCyDyuAo12Y7SSYqPX+kCrU6As6L6l/Y+7YOQdDUbu4/dwMNd74NUwmrfRERkiv8ykEUMeRAAcCElBwDQuR5VX7qTpaVcVYZVqN1YwpVsw0kuqVHwAAAxrXzh6SqHuqAEh89mWKllRERUnzCAIIsZFpQD9DfZ0S18bNia2mNMpNaYF0AYpjBxBIIcmVQiRs+YJgCAXQnXbdwaIiKyRwwgyGKGEQgA6BjmZ1wPor6xZARCU6JFfpF+P5ZxJUfXo21jSMQinE/JwbWMPFs3h4iI7AwDCLJYYz8XuJettNwlqn5VX7qT0kkfGBWYEUAY8h/kUrFJAjaRI/JyczKuJbErIcXGrSEiInvDAIIsJhaJ8OrjbTDq4UiT6Uz1jbGUa3HVq1Fn31HC9V5lMokcRZ92+mlM+0+no6DI/PVQiIio/mMAQdXSorFHvVp5uiKWTGFSsYQr1TNhwZ5o7OuC4hIt9p1KtXVziIjIjjCAIKqEJUnUxhKuzH+gekIkEqF32SjErqPXIQhCFUcQEVFDwQCCqBJKhT4HorCo6ilMt0cgWMKV6o8uUQFwkkuQeqsASVdUtm4OERHZCQYQRJWo1ggEpzBRPaJ0kqJLpL5Qws6jLOlKRER6DCCIKmFJDoRhDQhPTmGieqZ3e/00pqPnbiJLXWTj1hARkT1gAEFUCUVZGVezAghDFSaOQFA9E+TnitCmntAJAnYfv2Hr5hARkR1gAEFUCWfjCMS9cyAEQYAqjyMQVH/1KRuF+PvYDZRqdTZuDRER2RoDCKJKmDuFqaC4FCWl+psqLyZRUz3UPtQPHi5y5ORrkHAu09bNISIiG2MAQVQJc5OoDRWYXBRSyKSSWm8XUV2TSsTo0bYxAGBXApOpiYgaOgYQRJUwdwSCa0BQQ9AzpjHEIhHOXstGSmaerZtDREQ2xACCqBLKsiTqUq1gnKJUEa5CTQ2Bt7sC7Vr5AtAvLEdERA0XAwiiSijKpjAB9x6FYAlXaigMydT7TqWZVZ2MiIjqJwYQRJUQi0VQyKsu5Woo4cpF5Ki+C2/mhUAfZxRrtNh/Os3WzSEiIhthAEF0D8Y8iHskUqs4AkENhEgkQu92+lGInQnXIQiCjVtERES2wACC6B6MAUTRvUYgypKoOQJBDUCXqEA4ySS4cTMf565l27o5RERkAwwgiO5B
aZjCpKl8Mbnbi8hxDQiq/5wVUsRHNgKgH4UgIqKGhwEE0T1UVcpVq9NBnc8cCGpYercPAgAknMs0jsAREVHDwQCC6B4UZQFEQSUBhDq/BIIAiEUiuDlzBIIahqb+rmgV5AGtTsDuYzds3RwiIqpjDCCI7sG5bC2IokoCCMO3rx6ucojFojprF5Gt9S4r6frXseso1Va+TgoREdU/DCCI7uH2FKaKcyC4iBw1VB1C/eHuLEN2ngbHzt+0dXOIiKgO2TyA0Ol0WLRoEbp3746YmBi8/PLLuHbtWqX7nz9/HiNHjkSnTp0QHx+P8ePH48aN20PoRUVFmDdvHvr06YN27dph8ODB2LFjh8k5zpw5g6FDhyImJgZ9+vTB2rVra+31kWNTyu9dxtVYgYklXKmBkUnF6BHTGACwMyHFxq0hIqK6ZPMAYunSpdi4cSM++OADbNq0CTqdDiNGjIBGoym3r0qlwrBhw6BQKLBu3Tp89tlnyMrKwogRI1BcrL+R+/DDD/HLL7/g3XffxebNm9GvXz+MGzcOBw8eNDlHcHAwvv/+e4wdOxZz587F999/X6evmxxDVUnUt0cgmP9ADU/Ptk0gEgFJV7Nx/Wa+rZtDRER1xKYBhEajwerVqzF+/Hj06tUL4eHhmD9/PtLS0rBt27Zy+2/fvh0FBQWYPXs2QkNDERUVhTlz5uDixYtISEhAYWEhNm/ejNdffx09e/ZEs2bNMGbMGMTFxRkDhG+++QYymQzvv/8+WrRogccffxwvvvgiVq5cWdcvnxyAwsmwEnXFU5g4AkENmY+HAjEtfQEAOw5XPnJMRET1i00DiKSkJOTn5yM+Pt64zd3dHRERETh06FC5/ePj47F06VIoFArjNrFY/xLUajVEIhGWL1+OHj16mBwnFouhVqsBAIcPH0ZcXBykUqnx8c6dO+Py5cu4eZPzeMmUcxUjENnMgaAGrk9ZSde9J1Ir/TshIqL6RVr1LrUnLS0NABAYGGiy3d/f3/jYnYKCghAUFGSybeXKlVAoFIiNjYVCoUC3bt1MHj9x4gQOHDiA6dOnG58zNDS03PMBQGpqKnx9fav9eqRSy+MxiURs8l+qWl32mYtSBkCfA1HR+5udp59q5+OhqNb7Xxf4GbMc+8x8bVr6IMDbGWlZBdh9NAXxEY1s3SRei+sI+8xy7DPLsL8sV1d9ZtMAorCwEAAgl5vOH3dyckJOTk6Vx69btw7r16/H9OnT4e3tXe7xS5cuYezYsYiOjsZTTz0FQJ9kXdHzATDmUVSHWCyCl5dLtY93d1dW+9iGqi76rJGfPkAoLtFV+P5mly0i16yJZ43e/7rAz5jl2GfmeapfKBZ9cwyaEp3N+4zX4rrHPrMc+8wy7C/L1Xaf2TSAMExF0mg0JtOSiouLoVRW/sIFQcDChQuxbNkyjB49Gs8//3y5fRISEjBmzBgEBARg+fLlkMlkxue8O0HbEDg4OztX+7XodALU6gKLj5NIxHB3V0KtLoSWtdTNUpd9VqopAQDkF2qgUpkmiRaXaJFfqH9cIgjlHrcX/IxZjn1mmY6hvlj4WnfcF+RVrT5zd1da7dsyXovrDvvMcuwzy7C/LFeTPrPkWmzTAMIwdSkjIwPBwcHG7RkZGQgLC6vwmJKSEkydOhVbtmzB1KlT8eKLL5bbZ9u2bXjzzTfRtm1bLF26FG5ubsbHAgICkJGRYbK/4fdGjWo29F5aWv0Pt1arq9HxDVFd9Jm87A+pqFgLTYkWYtHtxeJuZutH0JxkEsgkIrt///gZsxz7zHxerk4QiUR20We8Ftct9pnl2GeWYX9Zrrb7zKaTysLDw+Hq6mossQrok6ETExMRGxtb4TGTJk3C77//jnnz5lUYPOzcuRMTJkxAr169sGrVKpPgAQBiY2Nx5MgRaLW3q+ocOHAA9913H3x8fKzzwqjeMJRxFQAUa0wrMWXfUcJVJOIq1ERERNQw2DSAkMvlGDp0KObOnYsdO3YgKSkJEyZM
QEBAAPr37w+tVovMzEwUFRUBAH744Qds3boVEyZMQFxcHDIzM40/RUVFyMnJweTJkxEZGYlp06YhJyfH+Hh2djYA4PHHH0deXh6mTZuGCxcu4IcffsCaNWswatQoG/YE2SuZVAyJWB8c3F1hRsUSrkRERNQA2XQKEwCMHz8epaWlmD59OoqKihAbG4tVq1ZBJpMhJSUFffv2xaxZszB48GBs2bIFADB79mzMnj3b5DyzZs2CTCaDWq3G8ePHy5VyjYuLw7p16+Dj44PPP/8cM2fOxGOPPQY/Pz9MmjQJjz32WJ29ZnIcIpEISicp8gpLygUQ2bn6XBqWcCUiIqKGxOYBhEQiwcSJEzFx4sRyjwUFBeHs2bPG31evXl3l+R566KEq94mOjsbXX39tWUOpwVLIJWUBxF1TmMpGIDw5AkFEREQNCAvrElXBuJic5q4pTFxEjoiIiBogBhBEVVBWshp1NnMgiIiIqAFiAEFUhcoCCNUdVZiIiIiIGgoGEERVUDpJAMAkB0IQBGTn6ZOovTiFiYiIiBoQBhBEVVCUjUAU3DECkV9UitKyFR49GEAQERFRA8IAgqgKhiTqojsCCMP0JVelDDIp/4yIiIio4eCdD1EVKsqBMJZw5egDERERNTAMIIiqoJSX5UBobudAGEYgWIGJiIiIGhoGEERVuPcIBCswERERUcPCAIKoCoqKAgiOQBAREVEDxQCCqArOFY5A6Eu4ejKAICIiogaGAQRRFRSGHIgKqjAxiZqIiIgaGgYQRFUwjkDckURtyIHgInJERETU0Eht3QAie6dU6P9MSkp1xsXj1PmcwkREREQNEwMIoioYpjAB+mlMJaU6CAAkYhHcnGW2axgRERGRDXAKE1EVJGIxnGS38yBUZdOXPFzlEItEtmwaERERUZ1jAEFkBoWTIYDQIjtXP32J+Q9ERETUEDGAIDLDnaVcby8ixwCCiIiIGh4GEERmUMhvBxDGEq5MoCYiIqIGiAEEkRmcDVOYNHeOQMht2SQiIiIim2AAQWQGpXEKk9Y4AuHFEQgiIiJqgBhAEJlBwRwIIiIiIgAMIIjMUlESNUcgiIiIqCFiAEFkBsNictl5xSgs1gLgCAQRERE1TAwgiMxgGIG4casAAOAklxjzIoiIiIgaEgYQRGYw5ECklQUQXESOiIiIGioGEERmMIxAFJcYpi+xhCsRERE1TAwgiMxw93QlJlATERFRQ8UAgsgMirKF5AyYQE1EREQNFQMIIjM43zUC4ckRCCIiImqgGEAQmUEhv2sKE0cgiIiIqIFiAEFkBo5AEBEREekxgCAyg1wmhlgkMv7OKkxERETUUDGAIDKDSCSC8o5EaiZRExERUUPFAILITIZSrm7OMkgl/NMhIiKihol3QURmMiRSM4GaiIiIGjIGEERmci6bwsQEaiIiImrIGEAQmUlRNoWJ+Q9ERETUkDGAIDKTs8IQQLACExERETVc0qp3ISIA6N4mENm5xegU0cjWTSEiIiKyGQYQRGZq3dwbrZt727oZRERERDbFKUxERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2BhBERERERGQ2kSAIgq0bUR8IggCdrnpdKZGIodXqrNyi+o19Zhn2l+XYZ5arbp+JxSKIRCKrtIHX4rrFPrMc+8wy7C/L1cW1mAEEERERERGZjVOYiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwgiIiIiIjIbAwg7MA777yDKVOmlNu+f/9+DB48GG3btsUDDzyAX3/91Qats19HjhxBWFhYuZ+DBw/auml2Q6fTYdGiRejevTtiYmLw8ssv49q1a7Zu
lt1KT0+v8DP1ww8/2LppdmnFihV4/vnnTbadOXMGQ4cORUxMDPr06YO1a9faqHWW47W4engtrhqvxZbj9dh8trgWS616NrKITqfDggUL8PXXX+Oxxx4zeezixYsYNWoUhg0bhjlz5uCvv/7CpEmT4O3tjfj4eBu12L6cPXsWwcHB2Lhxo8l2Dw8PG7XI/ixduhQbN27E//73PwQEBGDOnDkYMWIEfvnlF8jlcls3z+4kJSXByckJ27dvh0gkMm53c3OzYavs04YNG7BgwQJ07NjRuE2lUmHYsGHo06cP3nvvPRw7dgzvvfceXFxc8Pjjj9uwtffGa3HN8FpcNV6LLcfrsXlsdS1mAGEjFy9exLRp03DlyhU0bty43ONffvklwsLCMGHCBABAixYtkJiYiM8//5z/aJU5d+4cWrZsCT8/P1s3xS5pNBqsXr0ab775Jnr16gUAmD9/Prp3745t27Zh0KBBtm2gHTp37hyaN28Of39/WzfFbqWnp+Pdd9/FwYMH0bx5c5PHvvnmG8hkMrz//vuQSqVo0aIFrly5gpUrV9ptAMFrcc3xWnxvvBZXD6/H92brazGnMNnIgQMH0KJFC2zZsgVBQUHlHj98+HC5f5w6d+6MI0eOQBCEumqmXTt79ixatGhh62bYraSkJOTn55t8jtzd3REREYFDhw7ZsGX2i5+pqp0+fRoymQw///wz2rZta/LY4cOHERcXB6n09ndTnTt3xuXLl3Hz5s26bqpZeC2uOf7d3BuvxdXDz9W92fpazBEIGxkyZMg9H09LS0NAQIDJNn9/fxQWFkKlUsHb27s2m+cQzp8/Dy8vLwwePBjp6ekIDQ3FhAkTEB0dbeum2YW0tDQAQGBgoMl2f39/42Nk6ty5c/Dy8sKQIUOQnJyMZs2aYfTo0ejRo4etm2Y3+vTpgz59+lT4WFpaGkJDQ022Gb49TE1Nha+vb623z1K8Ftccr8X3xmtx9fB6fG+2vhYzgKgFKSkp6Nu3b6WP79+/v8p/dIqKisrNizT8rtFoat5IO1dVH/7111/Izc1FQUEBpk+fDolEgvXr12Po0KH44Ycf0LJlyzpsrX0qLCwEgHKfIycnJ+Tk5NiiSXattLQUly5dQsuWLTFlyhS4urri119/xciRI/HFF19wuooZKrpuOTk5AQCKi4vrvD28Ftccr8U1x2ux5Xg9rpm6uBYzgKgFjRo1wtatWyt93JzEMicnp3L/OBl+VyqVNWugA6iqD/39/XHo0CEolUrIZDIAQJs2bZCYmIh169bhvffeq6um2i2FQgFA/7kx/D+gv3g0hM+QpaRSKQ4ePAiJRGLsr6ioKJw/fx6rVq3iP1hmUCgU5a5bhn+snJ2d67w9vBbXHK/FNcdrseV4Pa6ZurgWM4CoBTKZrMbz9gIDA5GRkWGyLSMjA87Ozg2iAoE5feju7m7yu1gsRosWLZCenl6bTXMYhuHyjIwMBAcHG7dnZGQgLCzMVs2yay4uLuW2tWrVCnv37rVBaxxPQEBAhdctQH8jWtd4La45Xotrjtfi6uH1uPrq4lrMJGo71bFjR/z7778m2w4cOID27dtDLObbtnv3brRr186kjnZpaSmSkpI4ZF4mPDwcrq6uJrXY1Wo1EhMTERsba8OW2afz58+jffv25WrXnzp1ip8pM8XGxuLIkSPQarXGbQcOHMB9990HHx8fG7as+ngtvjdei6vGa7HleD2umbq4FvPqZ6eef/55nDhxAnPnzsXFixexevVq/P777xgxYoStm2YX2rdvDy8vL0yePBmnTp3C2bNnMXnyZGRnZ+PFF1+0dfPsglwux9ChQzF37lzs2LEDSUlJmDBhAgICAtC/f39bN8/utGjRAiEhIXj//fdx+PBhXLx4EbNmzcKxY8cwevRoWzfPITz++OPIy8vDtGnTcOHCBfzwww9Ys2YNRo0aZeumVRuvxffGa3HVeC22HK/HNVMX12JOYbJTrVq1wtKlSzFnzhx8+eWX
CAoKwpw5czjvr4yrqyvWrFmDuXPnYvjw4SguLkaHDh2wfv16u6z0Yivjx49HaWkppk+fjqKiIsTGxmLVqlXGucp0m1gsxvLlyzFv3jy89tprUKvViIiIwBdffFGumgVVzMfHB59//jlmzpyJxx57DH5+fpg0aVK5xdkcCa/F98ZrsXl4LbYMr8c1UxfXYpHAQtZERERERGQmTmEiIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgIiIiIiKzMYAgasAWL16MsLAwWzeDiIiIHAgDCCIiIiIiMhsDCCIiIiIiMhsDCCIbKyoqwrx589C/f39ERUWhffv2GDZsGM6cOVPpMS+99BIGDx5cbvuYMWPw8MMPG3//9ttvMXjwYMTExCA6OhqPPPIIfvvtt0rP26dPH0yZMsVk2w8//ICwsDCkpKQYt507dw6jRo1C+/bt0b59e4wdOxbXrl2z5GUTERGRg2IAQWRjkyZNwvfff4+RI0di9erVmDp1Ks6fP4833ngDgiBUeMzDDz+M06dP48qVK8ZtarUau3fvxiOPPAIA2LBhA9555x3069cPK1aswNy5cyGXy/Hmm28iLS2t2u1NTk7GM888g1u3buHjjz/GzJkzce3aNTz77LO4detWtc9LREREjoEBBJENaTQa5OfnY/r06XjiiScQFxeHJ598EsOHD8fFixdx8+bNCo/r378/nJ2dsWXLFuO2bdu2QavVYtCgQQCAa9euYfjw4RgzZgw6deqE/v37Y8aMGSgtLcWRI0eq3eYlS5ZAqVRizZo1uP/++zFw4ECsXbsWRUVF+Pzzz6t9XiIiInIMUls3gKghk8vlWLVqFQAgPT0dycnJuHz5Mnbt2gVAH2BUxNnZGf369cPWrVsxduxYAMCvv/6K+Ph4NGrUCACMU5HUajUuXbqEK1eu4ODBg/c8rzkOHDiAuLg4KBQKlJaWAgBcXV3RsWNH7Nu3r9rnJSIiIsfAAILIxvbs2YOPPvoIly5dgouLC8LDw+Hs7AwAlU5hAoBHHnkEP//8M5KSkuDr64uDBw/io48+Mj5+9epVvPPOO9i/fz9kMhlCQkIQHh5e5Xmrkp2dja1bt2Lr1q3lHvP29q72eYmIiMgxMIAgsqGrV69i7NixxjyFpk2bQiQSYcOGDdizZ889j42Pj4efnx9+++03+Pn5wcnJCf379wcA6HQ6jBw5EjKZDN999x1at24NqVSKCxcu4KeffrrnebVarcnvBQUFJr+7ubmhS5cuGDZsWLljpVJeUoiIiOo7/mtPZEOnTp1CcXExRo4cieDgYON2Q/Bwr5ECiUSChx56CLt27YK7uzv69etnHLlQqVRITk7GW2+9hTZt2hiP2b17NwB9gFERV1fXcgnWd+dLxMXF4cKFC8agxNDON998E82aNUPr1q3NfflERETkgBhAENlQZGQkpFIp5syZg5deegkajQY//PAD/vrrLwDlv/2/2yOPPILVq1dDLBbjs88+M2738fFBkyZNsGHDBgQEBMDd3R179uzB2rVrAQCFhYUVnq93795YsWIFVqxYgbZt22Lnzp04cOCAyT5jxozBM888g1GjRuHZZ5+Fk5MTvv76a2zfvh2LFi2qQW8QERGRI2AVJiIbatasGebNm4f09HSMHj0a77zzDgBg3bp1EIlEOHz48D2PDw8PR2hoKHx8fBAfH2/y2NKlS9GoUSNMmTIFr732Go4fP45ly5YhJCSk0vOOGjUKTz75JFatWoXRo0cjMzMTM2fOLPecGzZsgEgkwqRJkzB+/HhkZmbi008/NU6hIiIiovpLJNQkm5KIiIiIiBoUjkAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQEREREZHZGEAQERER
EZHZGEAQEREREZHZ/h9D0nxoL3QI0AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 800x400 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean W2 distance (exact score) versus the noise-schedule parameter a:\n",
    "# one curve per number of discretisation steps, one panel per scheme.\n",
    "# To add a panel, add an entry here, e.g. \"semii\": \"W2, exact score, Semi-Implicit scheme\".\n",
    "scheme_titles = {\"euler\": \"W2 distance, exact score, Euler scheme\"}\n",
    "\n",
    "# squeeze=False keeps axs indexable even when there is a single panel\n",
    "fig, axs = plt.subplots(figsize=(8,4), ncols=len(scheme_titles), sharey=True, sharex=True, squeeze=False)\n",
    "axs = axs[0]\n",
    "\n",
    "for ax, (scheme, title) in zip(axs, scheme_titles.items()):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == ns)\n",
    "        # groupby sorts by \"a\"; plot against its index so x and y always stay aligned\n",
    "        res_w2 = distances_df[select].groupby(\"a\")[\"w2\"].mean()\n",
    "        ax.plot(res_w2.index, res_w2.values, label=f\"num_steps = {ns}\")\n",
    "    ax.set_title(title)\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"W2 distance\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a11d6af1-614c-43bb-9b45-eb5efec6b786",
   "metadata": {},
   "outputs": [],
   "source": [
    "######## SELECT TRAINING PARAMETERS W.R.T THE DIMENSION ######################\n",
    "# optimisation parameters per dimension d: (n_epochs, learning_rate)\n",
    "TRAINING_PARAMS = {\n",
    "    5:  (2,   1.0e-4),  # NOTE(review): far fewer epochs than the other dimensions - confirm intended\n",
    "    10: (30,  1.0e-4),\n",
    "    25: (50,  1.0e-3),\n",
    "    50: (150, 1.0e-3),\n",
    "}\n",
    "# fail loudly instead of silently reusing n_epochs / learning_rate left over from a previous run\n",
    "if d not in TRAINING_PARAMS:\n",
    "    raise ValueError(f\"No training parameters configured for d={d}; expected one of {sorted(TRAINING_PARAMS)}\")\n",
    "n_epochs, learning_rate = TRAINING_PARAMS[d]\n",
    "\n",
    "# network_parameters\n",
    "mid_features = 256\n",
    "num_layers = 3\n",
    "batch_size = 64\n",
    "\n",
    "# Monte Carlo estimate samples (for E2 computation in the KL bound)\n",
    "num_mc = 500\n",
    "\n",
    "# number of runs (for replication)\n",
    "rep_size = 1\n",
    "\n",
    "# load data\n",
    "dataloader = DataLoader(training_sample_rescale, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# number of discretisation steps\n",
    "num_step = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d1b8c5c1-1289-4872-b9a5-5552b937210a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -10.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.072744: 100%|█████████████████| 150/150 [00:52<00:00,  2.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -9.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.065772: 100%|█████████████████| 150/150 [00:48<00:00,  3.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -8.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.061582: 100%|█████████████████| 150/150 [00:47<00:00,  3.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -7.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.066502: 100%|█████████████████| 150/150 [00:46<00:00,  3.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -6.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.065252: 100%|█████████████████| 150/150 [00:47<00:00,  3.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -5.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.068213: 100%|█████████████████| 150/150 [00:47<00:00,  3.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -4.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.066723: 100%|█████████████████| 150/150 [00:47<00:00,  3.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -3.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.073019: 100%|█████████████████| 150/150 [00:47<00:00,  3.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -2.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.068168: 100%|█████████████████| 150/150 [00:49<00:00,  3.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = -1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.073867: 100%|█████████████████| 150/150 [00:57<00:00,  2.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.077701: 100%|█████████████████| 150/150 [00:57<00:00,  2.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.060300: 100%|█████████████████| 150/150 [01:00<00:00,  2.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 2.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.053844: 100%|█████████████████| 150/150 [00:55<00:00,  2.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 3.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.054162: 100%|█████████████████| 150/150 [00:47<00:00,  3.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 4.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.056525: 100%|█████████████████| 150/150 [00:47<00:00,  3.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 5.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.057638: 100%|█████████████████| 150/150 [00:47<00:00,  3.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 6.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.092249: 100%|█████████████████| 150/150 [00:47<00:00,  3.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 7.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.076200: 100%|█████████████████| 150/150 [00:47<00:00,  3.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 8.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.069318: 100%|█████████████████| 150/150 [00:47<00:00,  3.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 9.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.088722: 100%|█████████████████| 150/150 [00:47<00:00,  3.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimization for a = 10.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.062789: 100%|█████████████████| 150/150 [00:47<00:00,  3.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 49min 41s, sys: 1h 14min 17s, total: 2h 3min 59s\n",
      "Wall time: 17min 23s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"\\n##### TO LOAD PRETRAINED MODELS\\nscore_theta_trained = [ [] for _ in a_values ]   # a_values, replicates\\nfor k, a in enumerate(a_values):\\n    score_theta_trained[k] = [ ]\\n    for j in range(rep_size):\\n        network = decoder.Decoder(sde.d, mid_features, num_layers) \\n        network.load_state_dict(torch.load(f'models/gaussian_aniso/d50_explicit_rescale_ep150'+f'/model_{k}_{j}_d'+str(d)+'.pt'))\\n        score_theta_trained[k].append(network)\\n\""
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "\n",
    "# LOAD_PRETRAINED=True restores networks saved by a previous run instead of retraining;\n",
    "# SAVE_MODELS=True writes each freshly trained network to MODEL_DIR.\n",
    "LOAD_PRETRAINED = False\n",
    "SAVE_MODELS = False\n",
    "MODEL_DIR = f'models/gaussian_aniso/d{d}_explicit_rescale_ep{n_epochs}'\n",
    "\n",
    "def model_path(k, j):\n",
    "    \"\"\"Checkpoint path for replicate j of the k-th value in a_values.\"\"\"\n",
    "    return os.path.join(MODEL_DIR, f'model_{k}_{j}_d{d}.pt')\n",
    "\n",
    "score_theta_trained = [ [] for _ in a_values ]   # a_values, replicates\n",
    "for k, a in enumerate(a_values):\n",
    "    score_theta_trained[k] = [ ]\n",
    "    if LOAD_PRETRAINED:\n",
    "        for j in range(rep_size):\n",
    "            network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "            # map_location lets checkpoints saved on a GPU be restored on a CPU-only machine\n",
    "            network.load_state_dict(torch.load(model_path(k, j), map_location=device))\n",
    "            score_theta_trained[k].append(network)\n",
    "        continue\n",
    "\n",
    "    print(f\"optimization for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "        #loss = diff.loss_conditional(network, sde) # to choose when the score function is not analytically available\n",
    "        loss = diff.loss_explicit(network, sde, diff.explicit_score(sde, dist)) # to choose in the Gaussian case for training with the analytical score function\n",
    "        diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "        score_theta_trained[k].append(network)\n",
    "        if SAVE_MODELS:\n",
    "            os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "            torch.save(network.state_dict(), model_path(k, j))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1294d023-5f43-4888-a157-773108757a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of discretisation steps passed to the samplers and to func.compute_E2 below\n",
    "num_steps = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb6be0ad-3d92-4d7e-9943-f583cdb75524",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]/Users/strasman/Desktop/CODE_UNIFORMISE/functions.py:454: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  rev_tk =  torch.tensor(sde.final_time - times[i], device = sde.device)\n",
      "/Users/strasman/Desktop/CODE_UNIFORMISE/functions.py:455: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  rev_tkp1 = torch.tensor(sde.final_time - times[i+1], device = sde.device)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulation and M for a = -10.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:09,  9.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulation and M for a = -9.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2it [00:19,  9.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulation and M for a = -8.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3it [00:29,  9.99s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulation and M for a = -7.0\n"
     ]
    }
   ],
   "source": [
    "### GAUSSIAN CASE\n",
    "# For each schedule parameter a and each trained replicate: sup-L2 score error\n",
    "# (second output of func.compute_E2) and W2 distance between target and a generated sample.\n",
    "\n",
    "# Schemes to evaluate; append (\"semii\", sp.EI_discr_sampler) to include the second sampler.\n",
    "schemes = [(\"euler\", sp.Euler_Maruyama_discr_sampler)]\n",
    "simulation_results = []\n",
    "\n",
    "# total= gives tqdm a length (enumerate has none), so the bar shows progress and ETA\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    print(f\"simulation and M for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    exp_score = diff.explicit_score(sde, dist)\n",
    "\n",
    "    for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "        _, error_approx_sup_L2 = func.compute_E2(dist, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "        for scheme_name, scheme_sampler in schemes:\n",
    "            sample = scheme_sampler(init, sde, score_theta, num_steps)\n",
    "            # map back with func.unnormalize (presumably undoing mean/std/rescale) before comparing\n",
    "            sample = func.unnormalize(sample, mean, std, rescale)\n",
    "            w2_value = w2(training_distribution, empirical(sample))\n",
    "            simulation_results.append({\n",
    "                \"a\": a,\n",
    "                \"replication\": j,\n",
    "                \"scheme\": scheme_name,\n",
    "                \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                \"w2\": w2_value,\n",
    "            })\n",
    "\n",
    "simulation_df = pd.DataFrame(simulation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe7f441a-8145-474c-8daa-6c37e903a6c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional persistence of simulation_df; both switches are off by default.\n",
    "# Only unpickle files you created yourself: pd.read_pickle can execute arbitrary code.\n",
    "sim_pkl_path = os.path.join('models/gaussian_aniso/d50_explicit_rescale_ep150_pkl', 'simulation_df.pkl')\n",
    "save_simulation_df = False\n",
    "load_simulation_df = False\n",
    "if save_simulation_df:\n",
    "    simulation_df.to_pickle(sim_pkl_path)\n",
    "if load_simulation_df:\n",
    "    simulation_df = pd.read_pickle(sim_pkl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56c8fc83-70d9-4b03-b8ba-eaade8ec0c45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the sup-L2 score error over replicates and plug it into the W2 bound for each a.\n",
    "epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "error_results = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    print(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    # alternatives: func.compute_E1 / func.compute_E3 (KL-bound terms), func.compute_w2_bound_empirical\n",
    "    w2_error = func.compute_w2_bound(dist, training_sample_rescale, sde, num_steps, epsilon[a])\n",
    "    error_results.append({\"a\": a, \"w2_error\": w2_error})\n",
    "\n",
    "errors_df = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1db49e1-28a6-446f-943c-3b87bf3efbe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the pre-computed bounds; this replaces any errors_df built in the previous cell.\n",
    "# To refresh the file instead: errors_df.to_pickle(errors_pkl_path)\n",
    "errors_pkl_path = os.path.join('models/gaussian_aniso/d50_explicit_rescale_ep150_pkl', 'errors_df.pkl')\n",
    "errors_df = pd.read_pickle(errors_pkl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831d0336-0c95-494a-a4de-6c5c36d7868e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "\n",
    "# W2 upper bound, divided by `rescale`\n",
    "W2_upperbound = results_df.groupby('a')['w2_error'].mean() / rescale\n",
    "\n",
    "# per-scheme W2 statistics of the generated samples\n",
    "w2_by_scheme = {\n",
    "    scheme: results_df.loc[results_df['scheme'] == scheme].groupby('a')['w2']\n",
    "    for scheme in ('euler', 'semii')\n",
    "}\n",
    "mean_all_Euler_W2 = w2_by_scheme['euler'].mean()\n",
    "std_all_Euler_W2 = w2_by_scheme['euler'].std()\n",
    "mean_all_EI_W2 = w2_by_scheme['semii'].mean()\n",
    "std_all_EI_W2 = w2_by_scheme['semii'].std()\n",
    "\n",
    "# a = 0 is used as the VPSDE baseline\n",
    "VPSDE_mean_W2, VPSDE_std_W2 = mean_all_Euler_W2[0.0], std_all_Euler_W2[0.0]\n",
    "\n",
    "# exact-score baselines at the same number of steps\n",
    "exact_at_num_steps = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "true_score_W2_Euler = distances_df.loc[exact_at_num_steps & (distances_df[\"scheme\"] == 'euler'), 'w2']\n",
    "true_score_W2_EI = distances_df.loc[exact_at_num_steps & (distances_df[\"scheme\"] == 'semii'), 'w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d14cc76-57b9-4c49-b27b-380e7432be98",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN ########\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap works across versions\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "# left axis: W2 upper bound with a +/- std band of the sup-L2 score error across replicates\n",
    "y1_color = cmap(1/20)\n",
    "upper = np.asarray(W2_upperbound, dtype=float)\n",
    "sup_l2_std = simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy()\n",
    "ax1.plot(a_values, upper, color=y1_color, label=\"Upper-bound\", linewidth=3)\n",
    "ax1.fill_between(a_values, upper - sup_l2_std, upper + sup_l2_std, fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (W2)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "# right axis: W2 of the NN-generated samples and the a = 0 (VPSDE) baseline\n",
    "y2_color = cmap(5/20)\n",
    "ax2 = ax1.twinx()\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label= r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \", linewidth=3)\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y2_color, label=\"VPSDE (NN)\", linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                 fc=y2_color, alpha=.2)\n",
    "ax2.set_ylabel(r\"$\\mathcal{W}_2$ distance\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"RESCALED_HETERO.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a83692c3-2f5e-4b5f-9557-9dd7a6ff2b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN (single axis, with the exact-score curve) ########\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap works across versions\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "# Fresh figure and axis: twinning the `ax1` left over from the previous cell would attach\n",
    "# a stray axis to an already displayed figure.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "y1_color = cmap(1/20)\n",
    "y2_color = cmap(5/20)\n",
    "\n",
    "# W2 upper bound with a +/- std band of the sup-L2 score error across replicates\n",
    "upper = np.asarray(W2_upperbound, dtype=float)\n",
    "sup_l2_std = simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy()\n",
    "ax.plot(a_values, upper, color=y1_color, label=\"Upper-bound\", linewidth=3)\n",
    "ax.fill_between(a_values, upper - sup_l2_std, upper + sup_l2_std, fc=y1_color, alpha=.3)\n",
    "\n",
    "# W2 of the generated samples: NN score, exact score, and the a = 0 (VPSDE) baseline\n",
    "ax.plot(a_values, mean_all_Euler_W2, color=y2_color, label= r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \", linewidth=3)\n",
    "ax.plot(a_values, true_score_W2_Euler, color=y2_color, label = r\"$\\mathcal{W}_2(\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\", linewidth=3, linestyle='dotted')\n",
    "ax.fill_between(a_values,\n",
    "                mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                fc=y2_color, alpha=.3)\n",
    "ax.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y2_color, label=\"VPSDE (NN)\", linewidth=2, linestyle='--')\n",
    "ax.fill_between(a_values,\n",
    "                VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                fc=y2_color, alpha=.2)\n",
    "ax.set_xlabel(\"Values of a\")\n",
    "ax.set_ylabel(r\"$\\mathcal{W}_2$ distance\")\n",
    "ax.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b74bf9-1b38-45a0-96d1-229f6cf457a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the W2 bound into errors_df2 (errors_df above was replaced by a pickled copy).\n",
    "# This cell must run before the plot below, which reads errors_df2.\n",
    "epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "error_results2 = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    print(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    w2_error = func.compute_w2_bound(dist, training_sample_rescale, sde, num_steps, epsilon[a])\n",
    "    error_results2.append({\"a\": a, \"w2_error\": w2_error})\n",
    "\n",
    "errors_df2 = pd.DataFrame(error_results2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148c847f-8323-4647-a410-7f4c6463b345",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recomputed W2 upper bound, divided by `rescale` like W2_upperbound.\n",
    "W2_upperbound2 = errors_df2.groupby('a')['w2_error'].mean() / rescale\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(a_values, W2_upperbound2)\n",
    "ax.set(xlabel=\"Values of a\", ylabel=\"Upper bound (W2), recomputed\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e6bebfd-4b2e-4354-958b-c644e030bb89",
   "metadata": {},
   "source": [
    "### Correlated Gaussian case (explicit score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1bdce6d-5ed5-4faa-a75c-50a2cbe57d25",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET Training data PARAMETERS #################\n",
    "n = 10000\n",
    "d = 50\n",
    "mu = torch.ones(d, device=device)\n",
    "\n",
    "### Correlated case: unit diagonal, off-diagonal entries 1/sqrt(|i-j| + 1)\n",
    "lag = torch.arange(d)\n",
    "lag = (lag[:, None] - lag[None, :]).abs()\n",
    "# evaluated in float64 on CPU, then cast/moved, matching element-wise float64 assignment\n",
    "SIGMA = (1.0 / torch.sqrt(lag.to(torch.float64) + 1.0)).to(device=device, dtype=torch.get_default_dtype())\n",
    "\n",
    "training_distribution = func.gaussian(d, mu, SIGMA)\n",
    "training_sample = training_distribution.generate_sample(n)\n",
    "\n",
    "########### SET DIFFUSION AND NOISING PARAMETERS #################\n",
    "sigma_inv = 1\n",
    "beta_min, beta_max = 0.1, 20\n",
    "a = 0.\n",
    "T = 1\n",
    "\n",
    "# noising schedule: func.beta_parametric, later re-parameterised for each a in a_values\n",
    "a_values = np.linspace(-10., 10., 21)  # 21 values from -10 to 10, step 1\n",
    "beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "#beta = func.beta_cosine(T, 0.003)   # alternative schedule\n",
    "\n",
    "sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11ac12c-87fd-4b14-957c-3ec8fb834625",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative diffusion settings, kept for reference only (not executed). They differ from\n",
    "# the active configuration in the previous cell only in sigma_inv = 1. and a = torch.tensor(0.).\n",
    "# sigma_inv = 1.\n",
    "# beta_min = 0.1\n",
    "# beta_max = 20\n",
    "# a = torch.tensor(0.)\n",
    "# T = 1\n",
    "# a_values = np.linspace(-10.,10.,21)\n",
    "# beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "# #beta = func.beta_cosine(T, 0.003)\n",
    "# sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32feca7c-e2a5-441c-b1c2-c06689a05897",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################# SET SAMPLING PARAMETERS #################################\n",
    "sample_batch_size = 10000 # size of the sample generated\n",
    "# Draw the starting points of the reverse process from this section's sde; otherwise the\n",
    "# samplers below would silently reuse an `init` left over from an earlier section.\n",
    "init = sde.final.generate_sample(sample_batch_size)\n",
    "num_steps = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89fe7360-f284-4b3e-9bb5-e127d45da2e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### EXACT SCORE SIMULATION GAUSSIAN CASE ####\n",
    "num_steps_vec = np.array([num_steps]) #to adjust as desired\n",
    "# Expensive: set to True to regenerate distances_df (the next cell reads a pickled copy).\n",
    "run_exact_score_simulation = False\n",
    "\n",
    "if run_exact_score_simulation:\n",
    "    distances = []\n",
    "    for ns in tqdm(num_steps_vec):\n",
    "        for a in a_values:\n",
    "            sde.beta.change_a(a)\n",
    "            score_theta = diff.explicit_score(sde, training_distribution)\n",
    "            for scheme_name, scheme in [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]:\n",
    "                sample = scheme(init, sde, score_theta, ns)\n",
    "                sample_dist = empirical(sample)  # built once, used for both metrics\n",
    "                distances.append({\n",
    "                    \"a\": a, \"num_steps\": ns,\n",
    "                    \"scheme\": scheme_name,\n",
    "                    \"loss\": \"exact_score\",\n",
    "                    \"kl\": kl(training_distribution, sample_dist),\n",
    "                    \"w2\": w2(training_distribution, sample_dist),\n",
    "                })\n",
    "    distances_df = pd.DataFrame(distances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f17f5afb-acb3-4394-88da-57c946317b37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# distances_df is read from a pickle (the exact-score simulation is not run by default).\n",
    "# To refresh the file after re-running it: distances_df.to_pickle(distances_pkl_path)\n",
    "distances_pkl_path = os.path.join('models/gaussian_covar/d50_explicit_pkl', 'distances_df.pkl')\n",
    "distances_df = pd.read_pickle(distances_pkl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a628a68-38a3-41f9-a107-dad76b6dfba8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KL between target and exact-score samples, one panel per sampling scheme\n",
    "fig, axs = plt.subplots(figsize=(8,4), ncols=2, sharey=True, sharex=True)\n",
    "panels = [(\"euler\", \"KL divergence, exact score, Euler scheme\"),\n",
    "          (\"semii\", \"KL divergence, exact score, Semi-Implicit scheme\")]\n",
    "exact_rows = distances_df[\"loss\"] == \"exact_score\"\n",
    "\n",
    "for ax, (scheme, title) in zip(axs, panels):\n",
    "    ax.set_title(title)\n",
    "    for ns in num_steps_vec:\n",
    "        mask = (distances_df[\"scheme\"] == scheme) & exact_rows & (distances_df[\"num_steps\"] == ns)\n",
    "        ax.plot(a_values, distances_df[mask].groupby(\"a\")[\"kl\"].mean(), label=f\"num_steps = {ns}\")\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"KL divergence\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "077b19de-5be3-4b56-93a5-f903dd105a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# W2 between target and exact-score samples, one panel per sampling scheme\n",
    "fig, axs = plt.subplots(figsize=(8,4), ncols=2, sharey=True, sharex = True)\n",
    "axs[0].set_title(\"W2 distance, exact score, Euler scheme\")\n",
    "axs[1].set_title(\"W2 distance, exact score, Semi-Implicit scheme\")\n",
    "\n",
    "for k, scheme in enumerate([\"euler\", \"semii\"]):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"]==ns)\n",
    "        res_w2 = distances_df[select].groupby(\"a\")[\"w2\"].mean()\n",
    "        axs[k].plot(a_values, res_w2, label=f\"num_steps = {ns}\")\n",
    "    axs[k].legend()\n",
    "    axs[k].set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"W2 distance\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26373db4-b61b-456a-995b-6bf529138334",
   "metadata": {},
   "outputs": [],
   "source": [
    "######## SELECT TRAINING PARAMETERS W.R.T THE DIMENSION ######################\n",
    "# optimisation parameters per dimension d: (n_epochs, learning_rate)\n",
    "training_params = {\n",
    "    5: (2, 1.0e-4),\n",
    "    10: (30, 1.0e-4),\n",
    "    25: (50, 1.0e-3),\n",
    "    50: (150, 1.0e-3),\n",
    "}\n",
    "# fail loudly instead of silently keeping n_epochs / learning_rate from an earlier section\n",
    "if d not in training_params:\n",
    "    raise ValueError(f\"no training parameters configured for d = {d}; expected one of {sorted(training_params)}\")\n",
    "n_epochs, learning_rate = training_params[d]\n",
    "\n",
    "# network_parameters\n",
    "mid_features = 256\n",
    "num_layers = 3\n",
    "batch_size = 64\n",
    "\n",
    "# Monte Carlo estimate samples (for E2 computation in the KL bound)\n",
    "num_mc = 500\n",
    "\n",
    "# number of runs (for replication)\n",
    "rep_size = 10\n",
    "\n",
    "# load data\n",
    "dataloader = DataLoader(training_sample, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# number of discretisation steps (previously misspelled `num_step`; this section reads `num_steps`)\n",
    "num_steps = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e884dde-c288-4ac2-9769-83986e454ffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "# By default the networks are restored from disk; set load_pretrained = False to retrain.\n",
    "load_pretrained = True\n",
    "save_trained = False  # write each retrained network to model_dir\n",
    "# One directory for both saving and loading (the old save path 'model/d{d}/...' did not\n",
    "# match the load path).\n",
    "# NOTE(review): the directory name says 'explicit' while the active training loss below is\n",
    "# loss_conditional; confirm which loss produced the saved checkpoints.\n",
    "model_dir = f'models/gaussian_covar/d{d}_explicit_ep{n_epochs}'\n",
    "\n",
    "score_theta_trained = [ [] for _ in a_values ]   # a_values, replicates\n",
    "for k, a in enumerate(a_values):\n",
    "    score_theta_trained[k] = [ ]\n",
    "    if not load_pretrained:\n",
    "        print(f\"optimization for a = {a}\")\n",
    "        sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        model_path = f'{model_dir}/model_{k}_{j}_d{d}.pt'\n",
    "        if load_pretrained:\n",
    "            # map_location lets checkpoints saved on a GPU be restored on a CPU-only machine\n",
    "            network.load_state_dict(torch.load(model_path, map_location=device))\n",
    "        else:\n",
    "            optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "            loss = diff.loss_conditional(network, sde) # to choose when the score function is not analytically available\n",
    "            #loss = diff.loss_explicit(network, sde, diff.explicit_score(sde, training_distribution)) # to choose in the Gaussian case for training with the analytical score function\n",
    "            diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "            if save_trained:\n",
    "                os.makedirs(model_dir, exist_ok=True)\n",
    "                torch.save(network.state_dict(), model_path)\n",
    "        score_theta_trained[k].append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4beef222-e5fd-425d-bfb3-9a5ad4e2ee94",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# For each schedule parameter a and each trained replicate: approximation errors from\n",
    "# func.compute_E2, then KL and W2 of a generated sample for both sampling schemes.\n",
    "simulation_results = []\n",
    "\n",
    "# total= gives tqdm a length (enumerate has none), so the bar shows progress and ETA\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    print(f\"simulation and M for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    exp_score = diff.explicit_score(sde, training_distribution)\n",
    "\n",
    "    for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "        error_approx_E2, error_approx_sup_L2 = func.compute_E2(training_distribution, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "        for scheme_name, scheme_sampler in [(\"euler\", sp.Euler_Maruyama_discr_sampler), (\"semii\", sp.EI_discr_sampler)]:\n",
    "            sample = scheme_sampler(init, sde, score_theta, num_steps)\n",
    "            sample_dist = empirical(sample)  # built once, used for both metrics\n",
    "            simulation_results.append({\n",
    "                \"a\": a,\n",
    "                \"replication\": j,\n",
    "                \"scheme\": scheme_name,\n",
    "                \"error_approx_E2\": error_approx_E2,\n",
    "                \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                \"kl\": kl(training_distribution, sample_dist),\n",
    "                \"w2\": w2(training_distribution, sample_dist),\n",
    "            })\n",
    "\n",
    "simulation_df = pd.DataFrame(simulation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7405097e-fdeb-4d25-bf15-6fb847b755fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# simulation_df is read from a pickle, replacing any simulation_df computed above.\n",
    "# To refresh the file instead: simulation_df.to_pickle(simulation_pkl_path)\n",
    "simulation_pkl_path = os.path.join('models/gaussian_covar/d50_explicit_pkl', 'simulation_df.pkl')\n",
    "simulation_df = pd.read_pickle(simulation_pkl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a38f5d8d-91d8-40d6-9294-1c2b65dee8f5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a9f3a14-e51f-44e7-a4df-f77534564310",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean sup-L2 score error per a feeds the W2 bound; E1 and E3 are the remaining KL-bound terms.\n",
    "epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "error_results = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    print(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    error_mixing_E1 = func.compute_E1(training_distribution, sde)\n",
    "    error_discr_E3 = func.compute_E3(training_distribution, sde, num_steps)\n",
    "    w2_error = func.compute_w2_bound(training_distribution, training_sample, sde, num_steps, epsilon[a])\n",
    "\n",
    "    error_results.append({\n",
    "        \"a\": a,\n",
    "        \"error_mixing_E1\": error_mixing_E1,\n",
    "        \"error_discr_E3\": error_discr_E3,\n",
    "        \"w2_error\": w2_error,\n",
    "    })\n",
    "\n",
    "errors_df = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b40a487-ea22-4a82-ada9-2ec6a2029db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional persistence of errors_df; both switches are off by default.\n",
    "# Only unpickle files you created yourself: pd.read_pickle can execute arbitrary code.\n",
    "errors_pkl_path = os.path.join('models/gaussian_covar/d50_explicit_pkl', 'errors_df.pkl')\n",
    "save_errors_df = False\n",
    "load_errors_df = False\n",
    "if save_errors_df:\n",
    "    errors_df.to_pickle(errors_pkl_path)\n",
    "if load_errors_df:\n",
    "    errors_df = pd.read_pickle(errors_pkl_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2e1e69-5db3-49c7-92b9-47716157e4f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "by_a = results_df.groupby('a')\n",
    "\n",
    "# KL upper bound = E1 (first value per a) + E2 (mean over replicates) + E3 (first value per a)\n",
    "error_mixing_E1 = by_a['error_mixing_E1'].first()\n",
    "error_approx_E2 = by_a['error_approx_E2'].mean()\n",
    "error_discr_E3 = by_a['error_discr_E3'].first()\n",
    "KL_upperbound = error_mixing_E1 + error_approx_E2 + error_discr_E3\n",
    "\n",
    "# W2 upper bound\n",
    "W2_upperbound = by_a['w2_error'].mean()\n",
    "\n",
    "# per-scheme statistics of the NN-generated samples\n",
    "euler_rows = results_df[results_df['scheme'] == 'euler'].groupby('a')\n",
    "ei_rows = results_df[results_df['scheme'] == 'semii'].groupby('a')\n",
    "\n",
    "mean_all_Euler_KL, std_all_Euler_KL = euler_rows['kl'].mean(), euler_rows['kl'].std()\n",
    "mean_all_EI_KL, std_all_EI_KL = ei_rows['kl'].mean(), ei_rows['kl'].std()\n",
    "mean_all_Euler_W2, std_all_Euler_W2 = euler_rows['w2'].mean(), euler_rows['w2'].std()\n",
    "mean_all_EI_W2, std_all_EI_W2 = ei_rows['w2'].mean(), ei_rows['w2'].std()\n",
    "\n",
    "# a = 0 is used as the VPSDE baseline\n",
    "VPSDE_mean_KL, VPSDE_std_KL = mean_all_Euler_KL[0.0], std_all_Euler_KL[0.0]\n",
    "VPSDE_mean_W2, VPSDE_std_W2 = mean_all_Euler_W2[0.0], std_all_Euler_W2[0.0]\n",
    "\n",
    "# exact-score baselines at the same number of steps\n",
    "exact_at_num_steps = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "exact_euler = distances_df[exact_at_num_steps & (distances_df[\"scheme\"] == 'euler')]\n",
    "exact_ei = distances_df[exact_at_num_steps & (distances_df[\"scheme\"] == 'semii')]\n",
    "true_score_KL_Euler, true_score_W2_Euler = exact_euler['kl'], exact_euler['w2']\n",
    "true_score_KL_EI, true_score_W2_EI = exact_ei['kl'], exact_ei['w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3665a2c-d3e7-4eb0-aff0-a1ddcde7fc5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT KL VS BOUND ########\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap works across versions\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "# left axis: KL upper bound with a +/- std band of E2 across replicates\n",
    "y1_color = cmap(1/20)\n",
    "kl_ub = np.array([x.item() for x in KL_upperbound])\n",
    "e2_std = results_df.groupby('a')['error_approx_E2'].std().to_numpy()\n",
    "ax1.plot(a_values, kl_ub, color=y1_color, label=\"Upper-bound\", linewidth=3)\n",
    "ax1.fill_between(a_values, kl_ub - e2_std, kl_ub + e2_std, fc=y1_color, alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (KL)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "# right axis: KL of generated samples (NN and exact score) and the a = 0 (VPSDE) baseline\n",
    "y2_color = cmap(5/20)\n",
    "ax2 = ax1.twinx()\n",
    "# raw strings: the old non-raw label contained an invalid '\\p' escape\n",
    "ax2.plot(a_values, mean_all_Euler_KL, color=y2_color, label=r\"KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (NN) \", linewidth=3)\n",
    "ax2.plot(a_values, true_score_KL_Euler, color=y2_color, label=r\"KL($\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\", linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_KL-std_all_Euler_KL,\n",
    "                 mean_all_Euler_KL+std_all_Euler_KL,\n",
    "                 fc=y2_color, alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_KL), color=y2_color, label=\"VPSDE (NN)\", linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_KL-VPSDE_std_KL,\n",
    "                 VPSDE_mean_KL+VPSDE_std_KL,\n",
    "                 fc=y2_color, alpha=.2)\n",
    "ax2.set_ylabel(\"KL divergence\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "# one legend for the curves of both y-axes (plt.legend() only saw ax2)\n",
    "handles1, labels1 = ax1.get_legend_handles_labels()\n",
    "handles2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles1 + handles2, labels1 + labels2)\n",
    "\n",
    "fig.tight_layout()\n",
    "#plt.savefig(\"fig/UB_and_KL_EI_classique_covar_d\"+str(d)+\".pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e19af2f9-bda6-468f-89d4-a96249297986",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN ########\n",
    "# W2 distance of the generated samples (right axis) as a function of the\n",
    "# noise-schedule parameter a. The upper-bound curve on ax1 is currently disabled.\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap is still supported.\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "y1_color = cmap(1/20)\n",
    "#ax1.plot(a_values, W2_upperbound, color=y1_color, label=\"Upper-bound\",linewidth=3)\n",
    "#ax1.fill_between(a_values,\n",
    "#                 ([x for x in W2_upperbound]) - simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy() ,\n",
    "#                 ([x for x in W2_upperbound]) + simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy() ,\n",
    "#                 fc=y1_color,alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (W2)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "color_0 = 'gray' #mean_all_Euler_KL\n",
    "color_bis = 'black'\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label= r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \",linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label = r\"$\\mathcal{W}_2(\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\",linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                 fc=y2_color,alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y3_color, label=\"VPSDE (NN)\",linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                 fc=y3_color,alpha=.2)\n",
    "ax2.set_ylabel(r\"$\\mathcal{W}_2$ distance\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# Only ax2 carries labelled artists here, so its legend is complete.\n",
    "ax2.legend(loc = 'upper center')\n",
    "#plt.savefig(\"covariance_rescale_num_steps_500_W_2_C_t+L_t_corr_M1.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "515874f0-4beb-471b-8621-c9d999087921",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce8f60ea-638d-4e1a-81b7-4075a731b30c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e49b88-dd94-45b1-8172-4ba477e5c918",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb59b820-3e9b-41a6-9fdc-7416f7080fd5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ded206e-755f-430b-bdc2-24c2a04c3585",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6c8d7415-5089-4014-af2d-9bbd5e359188",
   "metadata": {},
   "source": [
    "### COVAR RESCALE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c619612-c2e7-44b1-88d7-814e0c4c1fe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET Training data PARAMETERS #################\n",
    "n = 10000\n",
    "d = 50\n",
    "mu = torch.ones(d, device = device)\n",
    "\n",
    "### Correlated case: SIGMA[i, j] = 1 / sqrt(|i - j| + 1), hence ones on the diagonal.\n",
    "# One vectorised expression (computed in float64, then cast to the default dtype like\n",
    "# torch.eye) instead of filling the d*d entries one at a time from Python.\n",
    "idx = torch.arange(d, dtype=torch.float64)\n",
    "SIGMA = (1.0 / torch.sqrt(torch.abs(idx[:, None] - idx[None, :]) + 1.0)).to(device=device, dtype=torch.get_default_dtype())\n",
    "\n",
    "training_distribution = func.gaussian(d, mu, SIGMA)\n",
    "training_sample = training_distribution.generate_sample(n)\n",
    "\n",
    "########### SET DIFFUSION AND NOISING PARAMETERS #################\n",
    "sigma_inv = 1\n",
    "beta_min = 0.1\n",
    "beta_max = 20\n",
    "a = 0.\n",
    "T = 1\n",
    "\n",
    "# Set noising function and parameters\n",
    "a_values = np.linspace(-10.,10.,21) #set the parameter values for the beta_parametric function\n",
    "beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "#beta = func.beta_cosine(T, 0.003)\n",
    "\n",
    "sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67daa51d-6e98-4e23-9c1d-f73b3754e855",
   "metadata": {},
   "outputs": [],
   "source": [
    "rescale = np.sqrt(16) # rescaling factor passed to func.normalize (original note: set diag to 1/rescale)\n",
    "# Call func.normalize once. Output 0 is the rescaled sample; outputs 1 and 2 are the\n",
    "# parameters later passed to func.unnormalize.\n",
    "normalize_out = func.normalize(training_sample, rescale)\n",
    "training_sample_rescale = normalize_out[0]\n",
    "a = normalize_out[1]\n",
    "b = normalize_out[2]\n",
    "# NOTE(review): `a` is rebound by later cells that loop over a_values; cells needing the\n",
    "# normalisation parameters should recompute them rather than rely on this global.\n",
    "\n",
    "# Spectrum of the empirical covariance of the rescaled sample (smallest / largest |eigenvalue|)\n",
    "cov = func.compute_cov_matrix(training_sample_rescale)\n",
    "eigenvalues = torch.linalg.eigvals(cov)\n",
    "print(torch.min(torch.abs(eigenvalues)), torch.max(torch.abs(eigenvalues)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011dc917-134b-4ba0-9f1f-8cffb95e304f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the correlated covariance matrix SIGMA built in the parameter cell\n",
    "SIGMA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca08522a-1973-4343-a91b-37775fde838a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative rescaling kept for reference: plain division instead of func.normalize.\n",
    "# Kept as comments: as a bare triple-quoted string, Jupyter would echo the whole\n",
    "# string as this cell's output.\n",
    "# rescale = np.sqrt(16) #set diag to 1/rescale\n",
    "# training_sample_rescale = training_sample/np.sqrt(16)\n",
    "# cov = func.compute_cov_matrix(training_sample_rescale)\n",
    "# eigenvalues = torch.linalg.eigvals(cov)\n",
    "# print(torch.min(torch.abs(eigenvalues)), torch.max(torch.abs(eigenvalues)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f2c1c7-20a8-4980-8384-0852fc3c78f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian model of the rescaled data, assuming func.normalize divides the data by\n",
    "# `rescale` (consistent with cov = SIGMA / rescale**2).\n",
    "cov = SIGMA/(rescale)**2 #func.compute_cov_matrix(training_sample_rescale)\n",
    "# Mean 1/rescale per coordinate (the data mean mu is all ones). Written with `rescale`\n",
    "# instead of a hard-coded 4, and created on `device` like SIGMA.\n",
    "mean = torch.ones(d, device = device)/rescale #torch.mean(training_sample_rescale, dim=0) *0\n",
    "dist =  gaussian(sde.d,mean, cov)\n",
    "\n",
    "# Gaussian fitted to the empirical mean / covariance of the rescaled sample\n",
    "cov_emp = func.compute_cov_matrix(training_sample_rescale)\n",
    "mean_emp = torch.mean(training_sample_rescale, dim=0)\n",
    "dist_emp = gaussian(sde.d,mean_emp, cov_emp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e92af2-32f7-423b-aeff-86ad65d71ae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of discretisation steps, used by the bound computations and samplers below\n",
    "num_steps = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ffbdbbf-95dd-42a4-bde9-57bef2ba3a27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ellbar (func.compute_ellbar on the rescaled model) for each noise-schedule parameter a\n",
    "ellbar = []\n",
    "ellbar3 = []\n",
    "M1 = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values)):\n",
    "    print(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    #ellbar_emp.append(func.compute_ellbar_empirical(training_sample_rescale, sde, num_steps))\n",
    "    ellbar.append(func.compute_ellbar(dist, training_sample_rescale, sde, num_steps))\n",
    "    #ellbar3.append(func.compute_ellbar3(dist_emp, training_sample_rescale, sde, num_steps))\n",
    "    #M1.append(func.compute_M1_new(dist,sde,num_steps))\n",
    "\n",
    "plt.plot(a_values, ellbar, label = 'ell bar')\n",
    "# ellbar3 / M1 are only filled when the lines above are enabled; plotting an empty\n",
    "# list against a_values raises a shape-mismatch ValueError.\n",
    "if ellbar3:\n",
    "    plt.plot(a_values, ellbar3, label = 'ell bar 3')\n",
    "if M1:\n",
    "    plt.plot(a_values, M1, label = 'M1')\n",
    "plt.xlabel(\"Values of a\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f16fb95-97a6-485d-b835-672d16379cc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Baseline set-up: per-step constants k1 and k2 of the bound, for a = -10\n",
    "a = -10\n",
    "sde.beta.change_a(a)\n",
    "#rec = compute_ellbar(dist, training_sample_rescale, sde, num_steps)[1]\n",
    "#plt.plot(rec)\n",
    "k1 = []\n",
    "k2 = []\n",
    "#lambda_min = 0.01\n",
    "#lamnda_max= 1.\n",
    "# Step-independent quantities, computed once instead of at every step\n",
    "eigen = torch.abs(torch.linalg.eigvals(dist._sigma))\n",
    "lambda_min = torch.min(eigen)\n",
    "lambda_max = torch.max(eigen)\n",
    "# NOTE(review): norm_mu uses the un-rescaled training_sample while dist is the rescaled\n",
    "# model (cov = SIGMA / rescale**2) -- confirm this is intended.\n",
    "norm_mu = torch.norm(torch.mean(training_sample, axis = 0),p=2)\n",
    "times = torch.linspace(0, sde.final_time, num_steps+1)\n",
    "for i in range(len(times) - 1):\n",
    "    # times[i] is already a tensor: move it instead of re-wrapping it with torch.tensor\n",
    "    tk = times[i].to(sde.device)\n",
    "    tkp1 = times[i+1].to(sde.device)\n",
    "    # NOTE(review): k_1 uses mu(t_k+1) * lambda_min while k2 uses mu(t_k+1)**2 * lambda_min\n",
    "    # for the same variance term -- confirm which is intended.\n",
    "    k_1 = (1 /(sde.mu(tkp1) * lambda_min + sde.sigma(tk)**2))**2 \\\n",
    "              * (sde.beta(tkp1) * sde.mu(tk)**2 * ( lambda_max/ sde.sigma_infty**2 + 1))\n",
    "    k1.append(k_1)\n",
    "    k2.append( norm_mu * (sde.beta(tkp1) * sde.mu(tk) )/(2 * sde.sigma_infty**2) * \\\n",
    "              (1 / ( sde.mu(tkp1)**2 * lambda_min + sde.sigma(tk)**2) ) + sde.mu(tk)* k_1 )\n",
    "\n",
    "# The original plotted k1_new here, which is only defined by the next cell (NameError on a\n",
    "# fresh kernel); the 'k2' label refers to k2.\n",
    "plt.plot(k1, label = 'k1')\n",
    "plt.plot(k2, label = 'k2')\n",
    "plt.xlabel(\"step k\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51a3384-fe9a-43a0-8024-349f3d937dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "### After modification: compare the original k1 with the reworked k1_new (a = 2)\n",
    "a = 2\n",
    "sde.beta.change_a(a)\n",
    "#rec = compute_ellbar(dist, training_sample_rescale, sde, num_steps)[1]\n",
    "#plt.plot(rec)\n",
    "k1 = []\n",
    "k1_new = []\n",
    "k2 = []\n",
    "#lambda_min = 0.01\n",
    "#lamnda_max= 1.\n",
    "# Step-independent quantities, computed once instead of at every step\n",
    "eigen = torch.abs(torch.linalg.eigvals(dist._sigma))\n",
    "lambda_min = torch.min(eigen)\n",
    "lambda_max = torch.max(eigen)\n",
    "norm_mu = torch.norm(torch.mean(training_sample, axis = 0),p=2)\n",
    "sigma_infty_sq = sde.sigma_infty**2\n",
    "times = torch.linspace(0, sde.final_time, num_steps+1)\n",
    "for i in range(len(times) - 1):\n",
    "    # times[i] is already a tensor: move it instead of re-wrapping it with torch.tensor\n",
    "    tk = times[i].to(sde.device)\n",
    "    tkp1 = times[i+1].to(sde.device)\n",
    "    beta_kp1 = sde.beta(tkp1)\n",
    "    k_1 = (1 /(sde.mu(tkp1) * lambda_min + sde.sigma(tk)**2))**2 \\\n",
    "              * (beta_kp1 * sde.mu(tk)**2 * ( lambda_max/ sigma_infty_sq + 1))\n",
    "    k1.append(k_1)\n",
    "    # Reworked k1: beta(t_k+1) / sigma_infty^2 * |lambda_min - sigma_infty^2| divided by\n",
    "    # |(sigma_infty^2 + mu(t_k)^2 (lambda_min - sigma_infty^2)) * (sigma_infty^2 + mu(t_k+1)^2 (lambda_min - sigma_infty^2))|\n",
    "    k_1_over = beta_kp1 / sigma_infty_sq * torch.abs(lambda_min - sigma_infty_sq)\n",
    "    k_1_under = torch.abs( (sigma_infty_sq + sde.mu(tk)**2 * (lambda_min - sigma_infty_sq)) \\\n",
    "                          * (sigma_infty_sq + sde.mu(tkp1)**2 * (lambda_min - sigma_infty_sq)))\n",
    "    k1_new.append(k_1_over/k_1_under)\n",
    "    k2.append( norm_mu * (beta_kp1 * sde.mu(tk) )/(2 * sigma_infty_sq) * \\\n",
    "              (1 / ( sde.mu(tkp1)**2 * lambda_min + sde.sigma(tk)**2) ) + sde.mu(tk)* k_1 )\n",
    "\n",
    "plt.plot(k1, label = 'k1')\n",
    "plt.plot(k1_new, label = 'k1_new')\n",
    "plt.xlabel(\"step k\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0daeb8-6575-43d5-9c4c-8ebe2fe5da8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "### After modification: k2 against the reworked k1_new (a = 2)\n",
    "a = 2\n",
    "sde.beta.change_a(a)\n",
    "#rec = compute_ellbar(dist, training_sample_rescale, sde, num_steps)[1]\n",
    "#plt.plot(rec)\n",
    "k1_new = []\n",
    "k2 = []\n",
    "#lambda_min = 0.01\n",
    "#lamnda_max= 1.\n",
    "# Step-independent quantities, computed once instead of at every step\n",
    "eigen = torch.abs(torch.linalg.eigvals(dist._sigma))\n",
    "lambda_min = torch.min(eigen)\n",
    "lambda_max = torch.max(eigen)\n",
    "norm_mu = torch.norm(torch.mean(training_sample, axis = 0),p=2)\n",
    "sigma_infty_sq = sde.sigma_infty**2\n",
    "times = torch.linspace(0, sde.final_time, num_steps+1)\n",
    "for i in range(len(times) - 1):\n",
    "    # times[i] is already a tensor: move it instead of re-wrapping it with torch.tensor\n",
    "    tk = times[i].to(sde.device)\n",
    "    tkp1 = times[i+1].to(sde.device)\n",
    "    beta_kp1 = sde.beta(tkp1)\n",
    "    k_1_over = beta_kp1 / sigma_infty_sq * torch.abs(lambda_min - sigma_infty_sq)\n",
    "    k_1_under = torch.abs( (sigma_infty_sq + sde.mu(tk)**2 * (lambda_min - sigma_infty_sq)) \\\n",
    "                          * (sigma_infty_sq + sde.mu(tkp1)**2 * (lambda_min - sigma_infty_sq)))\n",
    "    k1_new.append(k_1_over/k_1_under)\n",
    "    # k2 needs this step's k_1. The original cell never computed it and silently reused\n",
    "    # whatever k_1 a previous cell left behind (NameError on a fresh kernel).\n",
    "    k_1 = (1 /(sde.mu(tkp1) * lambda_min + sde.sigma(tk)**2))**2 \\\n",
    "              * (beta_kp1 * sde.mu(tk)**2 * ( lambda_max/ sigma_infty_sq + 1))\n",
    "    k2.append( norm_mu * (beta_kp1 * sde.mu(tk) )/(2 * sigma_infty_sq) * \\\n",
    "              (1 / ( sde.mu(tkp1)**2 * lambda_min + sde.sigma(tk)**2) ) + sde.mu(tk)* k_1 )\n",
    "\n",
    "plt.plot(k2, label = 'k2')\n",
    "plt.plot(k1_new, label = 'k1_new')\n",
    "plt.xlabel(\"step k\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e04a910b-d840-4b28-9892-8b413c79824f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050bc8ba-bced-4e64-ad01-ecf969e069bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "sde.sigma(torch.tensor(0.01))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2af6553-3867-4b85-a9f8-e6825fa75fd3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e9d252-6630-4d83-bcb7-ca6a8bf766c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numerator of k1 alone, beta(t_k+1) * mu(t_k)^2 * (lambda_max / sigma_infty^2 + 1), along\n",
    "# the time grid. The full k1 also divides by (mu(t_k+1) * lambda_min + sigma(t_k)**2)**2.\n",
    "k1 = []\n",
    "# Step-independent quantities, computed once (lambda_min is displayed by the next cell)\n",
    "eigen = torch.abs(torch.linalg.eigvals(dist._sigma))\n",
    "lambda_min = torch.min(eigen)\n",
    "lambda_max = torch.max(eigen)\n",
    "times = torch.linspace(0, sde.final_time, num_steps+1)\n",
    "for i in range(len(times) - 1):\n",
    "    # times[i] is already a tensor: move it instead of re-wrapping it with torch.tensor\n",
    "    tk = times[i].to(sde.device)\n",
    "    tkp1 = times[i+1].to(sde.device)\n",
    "    k_1 = (sde.beta(tkp1) * sde.mu(tk)**2 * ( lambda_max/ sde.sigma_infty**2 + 1))\n",
    "    k1.append(k_1)\n",
    "\n",
    "plt.plot(k1)\n",
    "plt.xlabel(\"step k\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "624f509d-e3ea-43c2-87b1-03902200896b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smallest |eigenvalue| of dist._sigma, as left by the previous cell\n",
    "lambda_min"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5551de0-64b3-4dd4-b46d-80f71e550654",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_ellbar(dataset, training_sample, sde, num_steps):\n",
    "    \"\"\"Per-step bound constant ellbar over a uniform time grid.\n",
    "\n",
    "    For each of the num_steps steps of torch.linspace(0, sde.final_time, num_steps + 1),\n",
    "    computes kappa_1 and kappa_2 and keeps max(kappa_1, kappa_2).\n",
    "\n",
    "    Args:\n",
    "        dataset: distribution object whose covariance is read from ``dataset._sigma``.\n",
    "        training_sample: sample whose mean over axis 0 enters kappa_2 through its L2 norm.\n",
    "        sde: forward SDE providing mu, sigma, beta, sigma_infty, final_time and device.\n",
    "        num_steps: number of discretisation steps (must be >= 1).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (max over all steps, list of the per-step maxima).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if num_steps < 1 (no step to take the maximum over).\n",
    "    \"\"\"\n",
    "    if num_steps < 1:\n",
    "        raise ValueError(f\"num_steps must be >= 1, got {num_steps}\")\n",
    "    # Step-independent quantities, computed once instead of at every step.\n",
    "    eigen = torch.abs(torch.linalg.eigvals(dataset._sigma))\n",
    "    lambda_min = torch.min(eigen)\n",
    "    lambda_max = torch.max(eigen)\n",
    "    norm_mu = torch.norm(torch.mean(training_sample, axis = 0),p=2)\n",
    "\n",
    "    times = torch.linspace(0, sde.final_time, num_steps+1)\n",
    "    hist = []\n",
    "    for i in range(len(times) - 1):\n",
    "        # times[i] is already a tensor: move it rather than re-wrap it with torch.tensor.\n",
    "        tk = times[i].to(sde.device)\n",
    "        tkp1 = times[i+1].to(sde.device)\n",
    "        beta_kp1 = sde.beta(tkp1)\n",
    "        # NOTE(review): kappa_1 uses mu(t_k+1) * lambda_min whereas kappa_2 uses\n",
    "        # mu(t_k+1)**2 * lambda_min for the same variance term -- confirm which is intended.\n",
    "        kappa_1 = (1 /(sde.mu(tkp1) * lambda_min + sde.sigma(tk)**2))**2 * (beta_kp1 * sde.mu(tk)**2 * ( lambda_max/ sde.sigma_infty**2 + 1))\n",
    "        #kappa_2 = np.sqrt(sde.d) * (sde.beta(tkp1) * sde.mu(tk) )/(2 * sde.sigma_infty**2) * ( 1 / ( sde.mu(tkp1)**2 * lambda_min + sde.sigma(tk)**2) ) + sde.mu(tk)* kappa_1\n",
    "        kappa_2 = norm_mu * (beta_kp1 * sde.mu(tk) )/(2 * sde.sigma_infty**2) * ( 1 / ( sde.mu(tkp1)**2 * lambda_min + sde.sigma(tk)**2) ) + sde.mu(tk)* kappa_1\n",
    "        hist.append(np.max([kappa_1, kappa_2]))\n",
    "    return np.max(hist), hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "952082b5-81ad-4ba7-8426-c43bcbcd3d56",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################# SET SAMPLING PARAMETERS #################################\n",
    "num_steps = 500  # number of discretisation steps passed to the samplers\n",
    "sample_batch_size = 10000  # size of the sample generated\n",
    "# Shared starting points for every sampler run below, drawn from sde.final\n",
    "init = sde.final.generate_sample(sample_batch_size)\n",
    "\n",
    "# Optional quick checks with the exact score (disabled):\n",
    "#xbarT_euler = sp.Euler_Maruyama_discr_sampler(init, sde, diff.explicit_score(sde, training_distribution), num_steps)\n",
    "#xbarT_semii = sp.EI_discr_sampler(init, sde, diff.explicit_score(sde, training_distribution), num_steps)\n",
    "#num_steps_vec = np.array([num_steps//4, num_steps//2, num_steps ]) #to adjust as desired"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcdbe300-dfb0-42d8-b273-8a8a47e9d589",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### EXACT SCORE SIMULATION GAUSSIAN CASE ####\n",
    "# Normalisation parameters for func.unnormalize are recomputed here. The global names a / b\n",
    "# from the rescaling cell are overwritten by later cells, and the loop below used to rebind\n",
    "# `a` to the schedule parameter, so unnormalize received the schedule value instead.\n",
    "normalize_params = func.normalize(training_sample, rescale)\n",
    "norm_a, norm_b = normalize_params[1], normalize_params[2]\n",
    "#norm_a = torch.ones(d)\n",
    "#norm_b = diag_values\n",
    "num_steps_vec = np.array([num_steps]) #to adjust as desired\n",
    "distances = []\n",
    "for i, ns in tqdm(enumerate(num_steps_vec)):\n",
    "    for j, a_sched in enumerate(a_values):\n",
    "        sde.beta.change_a(a_sched)\n",
    "        score_theta = diff.explicit_score(sde, dist_emp)\n",
    "        for k, scheme in enumerate([sp.Euler_Maruyama_discr_sampler]):#, sp.EI_discr_sampler]):\n",
    "            sample = scheme(init, sde, score_theta, ns)\n",
    "            # Map the generated sample back to the original data scale\n",
    "            sample = func.unnormalize(sample, norm_a, norm_b, rescale)\n",
    "            distances.append({\n",
    "                \"a\": a_sched, \"num_steps\": ns, \n",
    "                \"scheme\": \"euler\" if k==0 else \"semii\", \n",
    "                \"loss\": \"exact_score\", \n",
    "                #\"kl\": kl(training_distribution, empirical(sample)), \n",
    "                \"w2\": w2(training_distribution, empirical(sample)),\n",
    "            })\n",
    "distances_df = pd.DataFrame(distances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f4fc44-17c9-4178-b4e0-9f45f1763b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(figsize=(8,4), ncols=2, sharey=True, sharex = True)\n",
    "axs[0].set_title(\"W2 distance, exact score, Euler scheme\")\n",
    "axs[1].set_title(\"W2, exact score, Semi-Implicit scheme\")\n",
    "\n",
    "for k, scheme in enumerate([\"euler\", \"semii\"]):\n",
    "    for ns in num_steps_vec:\n",
    "        select = (distances_df[\"scheme\"] == scheme) & (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"]==ns)\n",
    "        res_w2 = distances_df[select].groupby(\"a\")[\"w2\"].mean()\n",
    "        # A scheme can be disabled in the simulation cell (currently semii); plotting its\n",
    "        # empty result against a_values would raise a shape-mismatch ValueError.\n",
    "        if res_w2.empty:\n",
    "            continue\n",
    "        # Plot against the groupby index so x and y stay aligned.\n",
    "        axs[k].plot(res_w2.index, res_w2.to_numpy(), label=f\"num_steps = {ns}\")\n",
    "    if axs[k].get_lines():\n",
    "        axs[k].legend()\n",
    "    axs[k].set_xlabel(\"a value\")\n",
    "\n",
    "axs[0].set_ylabel(\"W2 distance\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e312e5c8-c1e7-403a-b2fb-d3e8aeaba904",
   "metadata": {},
   "outputs": [],
   "source": [
    "######## SELECT TRAINING PARAMETERS W.R.T THE DIMENSION ######################\n",
    "# optimisation parameters per supported dimension: d -> (n_epochs, learning_rate)\n",
    "TRAINING_PARAMS_BY_DIM = {\n",
    "    5:  (2, 1.0e-4),\n",
    "    10: (30, 1.0e-4),\n",
    "    25: (50, 1.0e-3),\n",
    "    50: (150, 1.0e-3),\n",
    "}\n",
    "if d not in TRAINING_PARAMS_BY_DIM:\n",
    "    # The former if/elif chain had no else branch: n_epochs / learning_rate stayed\n",
    "    # undefined (or stale from an earlier run) for any other d.\n",
    "    raise ValueError(f\"no training parameters configured for d = {d}; \"\n",
    "                     f\"supported: {sorted(TRAINING_PARAMS_BY_DIM)}\")\n",
    "n_epochs, learning_rate = TRAINING_PARAMS_BY_DIM[d]\n",
    "\n",
    "# network_parameters\n",
    "mid_features = 256\n",
    "num_layers = 3\n",
    "batch_size = 64\n",
    "\n",
    "# Monte Carlo estimate samples (for E2 computation in the KL bound)\n",
    "num_mc = 500\n",
    "\n",
    "# number of runs (for replication)\n",
    "rep_size = 3\n",
    "\n",
    "# load data\n",
    "dataloader = DataLoader(training_sample_rescale, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# number of discretisation steps\n",
    "# NOTE(review): `num_step` (singular) is distinct from the `num_steps` used by the samplers\n",
    "# and bounds -- confirm whether it is still needed.\n",
    "num_step = 150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "866f3a7f-1a64-4b20-96a5-89966f51caff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "\n",
    "# score_theta_trained[k][j]: network trained for a_values[k], replicate j\n",
    "score_theta_trained = [ [] for _ in a_values ]\n",
    "for k, a in enumerate(a_values):\n",
    "    print(f\"optimization for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "        #loss = diff.loss_conditional(network, sde) # to choose when the score function is not analytically available\n",
    "        loss = diff.loss_explicit(network, sde, diff.explicit_score(sde, dist)) # to choose in the Gaussian case for training with the analytical score function\n",
    "        diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "        score_theta_trained[k].append(network)\n",
    "        #torch.save(network.state_dict(), f'models/gaussian_covar/d50_explicit_rescale_standardize_ep150'+f'/model_{k}_{j}_d'+str(d)+'.pt')\n",
    "\n",
    "##### TO LOAD PRETRAINED MODELS (kept as comments: a bare triple-quoted string here\n",
    "# would be echoed as the cell's output).\n",
    "# NOTE(review): the save path above (d50_explicit_rescale_standardize_ep150) differs from\n",
    "# the load path below (d50_explicit_rescale_ep150) -- confirm which checkpoints to use.\n",
    "# NOTE: torch.load unpickles the file; only load checkpoints you trust.\n",
    "#score_theta_trained = [ [] for _ in a_values ]   # a_values, replicates\n",
    "#for k, a in enumerate(a_values):\n",
    "#    score_theta_trained[k] = [ ]\n",
    "#    for j in range(rep_size):\n",
    "#        network = decoder.Decoder(sde.d, mid_features, num_layers) \n",
    "#        network.load_state_dict(torch.load(f'models/gaussian_covar/d50_explicit_rescale_ep150'+f'/model_{k}_{j}_d'+str(d)+'.pt'))\n",
    "#        score_theta_trained[k].append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a459eb-9bce-482b-a5d9-9082a2491ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# For each schedule parameter a and each trained replicate: sup-L2 score error (second\n",
    "# output of func.compute_E2) and W2 distance between training distribution and sample.\n",
    "\n",
    "simulation_results = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values)):\n",
    "    print(f\"simulation and M for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    exp_score = diff.explicit_score(sde, dist_emp)\n",
    "\n",
    "    for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "        _, error_approx_sup_L2 = func.compute_E2(dist, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "        for scheme_name, scheme_sampler in [(\"euler\", sp.Euler_Maruyama_discr_sampler)]:#, (\"semii\", sp.EI_discr_sampler)]:\n",
    "            sample = scheme_sampler(init, sde, score_theta, num_steps)\n",
    "            #kl_value = kl(training_distribution, empirical(sample * rescale))\n",
    "            # Back to the original scale with `rescale` instead of a hard-coded 4 (same value here).\n",
    "            # NOTE(review): the exact-score cell maps samples back with func.unnormalize instead;\n",
    "            # confirm both paths apply the same inverse transform.\n",
    "            w2_value = w2(training_distribution, empirical(sample * rescale))\n",
    "\n",
    "            simulation_results.append({\n",
    "                \"a\": a,\n",
    "                \"replication\": j,\n",
    "                \"scheme\": scheme_name,\n",
    "                #\"error_approx_E2\": error_approx_E2,\n",
    "                \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                #\"kl\": kl_value,\n",
    "                \"w2\": w2_value,\n",
    "            })\n",
    "\n",
    "simulation_df = pd.DataFrame(simulation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2deebf5d-fb89-49ef-a2a8-6c08b5c712ea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c78ecd99-4f80-436e-8329-2d68a8831630",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean sup-L2 score-approximation error per schedule parameter, plugged into the W2 bound\n",
    "epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "error_results = []\n",
    "\n",
    "for k, a in enumerate(tqdm(a_values)):\n",
    "    print(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    w2_error = func.compute_w2_bound(dist, training_sample_rescale, sde, num_steps, epsilon[a])\n",
    "    error_results.append({\"a\": a, \"w2_error\": w2_error})\n",
    "    # Other bound terms, currently disabled:\n",
    "    #error_mixing_E1 = func.compute_E1(training_distribution, sde)\n",
    "    #error_discr_E3 = func.compute_E3(training_distribution, sde, num_steps)\n",
    "    #w2_error2 = func.compute_w2_bound2(dist_emp, training_sample_rescale, sde, num_steps, epsilon[a])\n",
    "    #w2_error_emp = func.compute_w2_bound_empirical(training_sample_rescale, sde, num_steps, epsilon[a])\n",
    "\n",
    "errors_df = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade2a3a8-82fb-4750-aabd-16af10fcd455",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a594ef03-dcbb-4733-a38a-ee01af6cba78",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df = pd.merge(simulation_df, errors_df, on=\"a\", how=\"left\")\n",
    "\n",
    "# W2 upper bound per schedule parameter (w2_error is repeated for every replicate)\n",
    "W2_upperbound = results_df.groupby('a')['w2_error'].mean() #* 1/rescale\n",
    "\n",
    "# Generated-sample W2 per scheme: mean and std over replicates, per value of a\n",
    "w2_by_scheme = {scheme: results_df.loc[results_df['scheme'] == scheme].groupby('a')['w2'] for scheme in ('euler', 'semii')}\n",
    "mean_all_Euler_W2 = w2_by_scheme['euler'].mean()\n",
    "std_all_Euler_W2 = w2_by_scheme['euler'].std()\n",
    "mean_all_EI_W2 = w2_by_scheme['semii'].mean()\n",
    "std_all_EI_W2 = w2_by_scheme['semii'].std()\n",
    "\n",
    "# The a = 0.0 Euler entry serves as the VPSDE reference in the plots\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2[0.0]\n",
    "\n",
    "# Exact-score W2 at the current num_steps, per scheme\n",
    "exact_at_num_steps = (distances_df[\"loss\"] == \"exact_score\") & (distances_df[\"num_steps\"] == num_steps)\n",
    "true_score_W2_Euler = distances_df.loc[exact_at_num_steps & (distances_df[\"scheme\"] == 'euler'), 'w2']\n",
    "true_score_W2_EI = distances_df.loc[exact_at_num_steps & (distances_df[\"scheme\"] == 'semii'), 'w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd58de05-0cd3-4ac9-8f41-67c4747ad625",
   "metadata": {},
   "outputs": [],
   "source": [
    "# W2 upper bound per value of a (mean of w2_error over replicates)\n",
    "W2_upperbound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "856f4646-38b8-49cf-ae05-8954edc5bdfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN ########\n",
    "# Twin-axis figure: W2 upper bound (left axis) vs. W2 distance of the generated\n",
    "# samples (right axis), as a function of the noise-schedule parameter a.\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap is still supported.\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "y1_color = cmap(1/20)\n",
    "W2_upperbound_values = np.asarray([x for x in W2_upperbound], dtype=float)\n",
    "# Band around the bound: +/- std over replicates of the sup-L2 score error, per value of a\n",
    "eps_std = simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy()\n",
    "ax1.plot(a_values, W2_upperbound_values, color=y1_color, label=\"Upper-bound\",linewidth=3)\n",
    "ax1.fill_between(a_values,\n",
    "                 W2_upperbound_values - eps_std,\n",
    "                 W2_upperbound_values + eps_std,\n",
    "                 fc=y1_color,alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (W2)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "color_0 = 'gray' #mean_all_Euler_KL\n",
    "color_bis = 'black'\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label= r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \",linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label = r\"$\\mathcal{W}_2(\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\",linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                 fc=y2_color,alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y3_color, label=\"VPSDE (NN)\",linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                 fc=y3_color,alpha=.2)\n",
    "ax2.set_ylabel(r\"$\\mathcal{W}_2$ distance\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# plt.legend() would only list ax2's artists; merge both axes' handles so the\n",
    "# upper-bound curve also appears in the legend.\n",
    "handles_1, labels_1 = ax1.get_legend_handles_labels()\n",
    "handles_2, labels_2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles_1 + handles_2, labels_1 + labels_2, loc = 'center right')\n",
    "#plt.savefig(\"divlambdamax_COV.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "189bc92e-d58a-47a6-b04c-17f998f725d1",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcbe8473-8ac6-4f09-877f-ba733f0a7bd9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "49291f1e-13ee-4cd9-b5f4-1c263f4d6230",
   "metadata": {},
   "source": [
    "### Correlated change sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5b6027-496c-49bb-aeb7-0b5169f094b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "########### SET Training data PARAMETERS #################\n",
    "n = 10000\n",
    "d = 50\n",
    "mu = torch.ones(d, device = device)\n",
    "\n",
    "### Correlated case: SIGMA[i, j] = 1 / sqrt(|i - j| + 1), hence ones on the diagonal.\n",
    "# One vectorised expression (computed in float64, then cast to the default dtype like\n",
    "# torch.eye) instead of filling the d*d entries one at a time from Python.\n",
    "idx = torch.arange(d, dtype=torch.float64)\n",
    "SIGMA = (1.0 / torch.sqrt(torch.abs(idx[:, None] - idx[None, :]) + 1.0)).to(device=device, dtype=torch.get_default_dtype())\n",
    "\n",
    "training_distribution = func.gaussian(d, mu, SIGMA)\n",
    "training_sample = training_distribution.generate_sample(n)\n",
    "\n",
    "########### SET DIFFUSION AND NOISING PARAMETERS #################\n",
    "sigma_inv = 4.\n",
    "beta_min = 0.1\n",
    "beta_max = 20\n",
    "a = 0.\n",
    "T = 1\n",
    "\n",
    "# Set noising function and parameters\n",
    "a_values = np.linspace(-10.,10.,21) #set the parameter values for the beta_parametric function\n",
    "beta = func.beta_parametric(a, T, beta_min, beta_max)\n",
    "#beta = func.beta_cosine(T, 0.003)\n",
    "\n",
    "sde = diff.forward_VPSDE(d, beta, sigma_inv, T, device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e61439b-9ef4-48d9-8cff-1172ff576471",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant ell-bar from func.compute_ellbar, evaluated for every schedule parameter a.\n",
    "# NOTE(review): `dist`, `training_sample_rescale` and `num_steps` are not defined in this\n",
    "# section; they are inherited from earlier cells - confirm they describe the correlated data above.\n",
    "ellbar = []\n",
    "ellbar3 = []\n",
    "M1 = []\n",
    "\n",
    "# iterate over a_values directly so tqdm knows the total length\n",
    "for a in tqdm(a_values, desc=\"ellbar\"):\n",
    "    tqdm.write(f\"bound for a = {a}\")  # tqdm.write keeps the progress bar intact\n",
    "    sde.beta.change_a(a)\n",
    "    #ellbar_emp.append(func.compute_ellbar_empirical(training_sample_rescale, sde, num_steps))\n",
    "    ellbar.append(func.compute_ellbar(dist, training_sample_rescale, sde, num_steps))\n",
    "    #ellbar3.append(func.compute_ellbar3(dist_emp, training_sample_rescale, sde, num_steps))\n",
    "    #M1.append(func.compute_M1_new(dist,sde,num_steps))\n",
    "# note: sde.beta is left set to the last value of a_values\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(a_values, ellbar, label = 'ell bar')\n",
    "#ax.plot(a_values, ellbar3, label = 'ell bar 3')\n",
    "#ax.plot(a_values, M1, label = 'M1')\n",
    "ax.set_xlabel(\"Values of a\")\n",
    "ax.set_ylabel(\"ell bar\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "279f330e-9f9a-4a40-9b91-8170434bdde0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the current value of `num_steps`, the number of discretisation steps read by\n",
    "# the ellbar, bound, E2 and sampling cells of this section.\n",
    "num_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de016af7-a644-4989-a716-631bb6e4820d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# W2 upper bound (func.compute_w2_bound) for every schedule parameter a, with epsilon set to the\n",
    "# mean measured score approximation error (error_approx_sup_L2) of the runs for that a.\n",
    "# NOTE(review): `simulation_df` is built by the simulation cell further down, so that cell must\n",
    "# run first; `dist` and `training_sample_rescale` come from earlier cells.\n",
    "epsilon = simulation_df.groupby('a')['error_approx_sup_L2'].mean()\n",
    "missing_a = [a for a in a_values if a not in epsilon.index]\n",
    "if missing_a:\n",
    "    raise KeyError(f\"simulation_df has no error_approx_sup_L2 for a = {missing_a}; run the simulation cell first\")\n",
    "\n",
    "error_results = []\n",
    "for a in tqdm(a_values, desc=\"W2 bound\"):\n",
    "    tqdm.write(f\"bound for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    #error_mixing_E1 = func.compute_E1(training_distribution, sde)\n",
    "    #error_discr_E3 = func.compute_E3(training_distribution, sde, num_steps)\n",
    "    w2_error = func.compute_w2_bound(dist, training_sample_rescale, sde, num_steps, epsilon.loc[a])\n",
    "    #w2_error_emp = func.compute_w2_bound_empirical(training_sample_rescale, sde, num_steps, epsilon.loc[a])\n",
    "\n",
    "    error_results.append({\n",
    "        \"a\": a,\n",
    "        #\"error_mixing_E1\": error_mixing_E1,\n",
    "        #\"error_discr_E3\": error_discr_E3,\n",
    "        \"w2_error\": w2_error,\n",
    "    })\n",
    "\n",
    "# one row per a; merged into simulation_df by the results cells\n",
    "errors_df = pd.DataFrame(error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba539b7d-2b48-464b-a4f7-3e530855961b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_ellbar_improved(dataset, training_sample, sde, num_steps):\n",
    "    \"\"\"Per-step constants max(kappa_1, kappa_2) on a uniform grid of [0, sde.final_time].\n",
    "\n",
    "    dataset         : object exposing the data covariance matrix as `dataset._sigma`\n",
    "    training_sample : samples whose column-wise mean enters kappa_2 through its L2 norm\n",
    "    sde             : forward SDE exposing final_time, device, mu, sigma, beta and sigma_infty\n",
    "    num_steps       : number of discretisation steps\n",
    "\n",
    "    Returns a list of `num_steps` Python floats, one per step [t_k, t_{k+1}].\n",
    "    \"\"\"\n",
    "    times = torch.linspace(0, sde.final_time, num_steps+1)\n",
    "\n",
    "    # Loop-invariant quantities, computed once instead of at every step.\n",
    "    eigen = torch.abs(torch.linalg.eigvals(dataset._sigma))\n",
    "    lambda_min = torch.min(eigen)\n",
    "    lambda_max = torch.max(eigen)\n",
    "    norm_mu = torch.norm(torch.mean(training_sample, dim=0), p=2)\n",
    "\n",
    "    hist = []\n",
    "    for i in range(num_steps):\n",
    "        # .to() gives the same value as torch.tensor(<tensor>) without the copy-construct warning\n",
    "        tk = times[i].to(sde.device)\n",
    "        tkp1 = times[i+1].to(sde.device)\n",
    "        beta_kp1 = sde.beta(tkp1)\n",
    "        # NOTE(review): kappa_1 uses mu(t_{k+1}) * lambda_min in its denominator while kappa_2 uses\n",
    "        # mu(t_{k+1})**2 * lambda_min - confirm which power is intended.\n",
    "        kappa_1 = (1 / (sde.mu(tkp1) * lambda_min + sde.sigma(tk)**2))**2 * (beta_kp1 * sde.mu(tk)**2 * (lambda_max / sde.sigma_infty**2 + 1))\n",
    "        kappa_2 = norm_mu * (beta_kp1 * sde.mu(tk)) / (2 * sde.sigma_infty**2) * (1 / (sde.mu(tkp1)**2 * lambda_min + sde.sigma(tk)**2)) + sde.mu(tk) * kappa_1\n",
    "        # torch.maximum also works on CUDA tensors; np.max needed a (CPU-only) numpy conversion\n",
    "        hist.append(torch.maximum(kappa_1, kappa_2).item())\n",
    "    return hist\n",
    "\n",
    "\n",
    "def compute_w2_bound_improved(dataset, training_sample, sde, num_steps, epsilon):\n",
    "    \"\"\"Work-in-progress W2 upper bound: mixing + approximation/discretisation + epsilon terms.\n",
    "\n",
    "    NOTE(review): as written this function cannot run end to end:\n",
    "      * L2_norm_estimator, compute_mixing_w2, compute_Ct and compute_Lt are neither defined in\n",
    "        this cell nor prefixed with `func.` - confirm they exist in the notebook namespace;\n",
    "      * integral_approximation_Lt_beta_mt, integral_approximation_Ct_beta_mt and\n",
    "        integral_approximation_beta_mt are never assigned (NameError in the DELTA loop);\n",
    "      * `ellbar` is the list returned by compute_ellbar_improved, so `ellbar * h` raises\n",
    "        TypeError - presumably a per-step element or an aggregate was meant;\n",
    "      * DELTA / DELTA_v are computed but not used in the returned value;\n",
    "      * h = 1/num_steps, the mixing grid on [0, 1] and `1.-times[i]` assume sde.final_time == 1.\n",
    "    \"\"\"\n",
    "\n",
    "    #constants computation\n",
    "    h = 1/num_steps\n",
    "    T = sde.final_time\n",
    "    B = np.sqrt(L2_norm_estimator(training_sample)**2 + sde.sigma_infty**2 * sde.d)\n",
    "    beta_final = sde.beta(torch.tensor(sde.final_time))\n",
    "    M_2 = np.sqrt(2*h* beta_final)/sde.sigma_infty + h*beta_final/(2* sde.sigma_infty**2)\n",
    "    ellbar = compute_ellbar_improved(dataset, training_sample, sde, num_steps)\n",
    "\n",
    "    #mixing\n",
    "    mixing =  compute_mixing_w2(dataset,sde)\n",
    "    t_points = torch.linspace(0, 1, steps=100)\n",
    "    ct_values = torch.tensor([compute_Ct(dataset, sde, t) * sde.beta(t) for t in t_points])\n",
    "    integral_approximation_Ct = torch.trapezoid(ct_values, t_points)\n",
    "    mixing *= torch.exp(- integral_approximation_Ct)\n",
    "\n",
    "    DELTA = 1.\n",
    "    DELTA_v = np.ones(num_steps)\n",
    "\n",
    "    #appprox+discr\n",
    "    aprox_discr = 0\n",
    "    times = torch.linspace(0, sde.final_time, num_steps+1) \n",
    "    for i in range(len(times) - 1):\n",
    "        rev_tk =  torch.tensor(times[i], device = sde.device) #T-tk\n",
    "        rev_tkp1 = torch.tensor(times[i+1], device = sde.device) #T-tkp1 \n",
    "        # NOTE(review): this grid runs from times[i+1] down to times[i], so torch.trapezoid returns\n",
    "        # minus the integral over the step - confirm the intended sign.\n",
    "        t_points = torch.linspace(rev_tkp1, rev_tk, steps=100)\n",
    "        Lt_beta_values = torch.tensor([compute_Lt(dataset, sde, t) * sde.beta(t) for t in t_points])\n",
    "        integral_approximation_Lt_beta = torch.trapezoid(Lt_beta_values, t_points)\n",
    "        aprox_discr += integral_approximation_Lt_beta * (M_2 + 2*integral_approximation_Lt_beta)*B\n",
    "\n",
    "    for i in range(len(times) - 1):\n",
    "        tkp1 =  torch.tensor(1.-times[i], device = sde.device) #T-tk\n",
    "        tk = torch.tensor(1.-times[i+1], device = sde.device) #T-tkp1 \n",
    "            \n",
    "        delta = (1 + .5 * sde.mu(1.-tk)**2/sde.mu(1.-tkp1)**2 * integral_approximation_Lt_beta_mt**2 \\\n",
    "                                                        - sde.mu(1.-tk)/sde.mu(1.-tkp1) * integral_approximation_Ct_beta_mt \\\n",
    "                                                        + sde.mu(1.-tk)**2/sde.mu(1.-tkp1)**2 * ellbar * h * integral_approximation_beta_mt)\n",
    "        DELTA *= delta\n",
    "        DELTA_v[num_steps- 1- i] = DELTA\n",
    "\n",
    "    \n",
    "    const_2 = epsilon* T * beta_final\n",
    "    const_3 = ellbar*h* T * beta_final* (1 + 2 * B ) \n",
    "    \n",
    "    return mixing + aprox_discr + const_2 + const_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e7057e-d986-47f3-ae16-abecedebbdfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-a summaries used by the W2 plot.\n",
    "# NOTE(review): an identical cell follows the simulation cell below; this copy reads whatever\n",
    "# `simulation_df` / `errors_df` exist when it runs - confirm which copy is intended.\n",
    "results_df = simulation_df.merge(errors_df, on=\"a\", how=\"left\")\n",
    "\n",
    "# W2 upper bound per a (errors_df carries one value per a)\n",
    "W2_upperbound = results_df.groupby('a')['w2_error'].mean() #* 1/rescale\n",
    "\n",
    "# empirical W2 of the generated samples, per discretisation scheme and per a\n",
    "mean_all_Euler_W2 = results_df.loc[results_df['scheme'].eq('euler')].groupby('a')['w2'].mean()\n",
    "std_all_Euler_W2 = results_df.loc[results_df['scheme'].eq('euler')].groupby('a')['w2'].std()\n",
    "mean_all_EI_W2 = results_df.loc[results_df['scheme'].eq('semii')].groupby('a')['w2'].mean()\n",
    "std_all_EI_W2 = results_df.loc[results_df['scheme'].eq('semii')].groupby('a')['w2'].std()\n",
    "\n",
    "# the a = 0 Euler run is the VPSDE reference\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2.loc[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2.loc[0.0]\n",
    "\n",
    "# exact-score W2 for the current num_steps (distances_df is defined outside this section)\n",
    "true_score_W2_Euler = distances_df.loc[distances_df[\"scheme\"].eq('euler') & distances_df[\"loss\"].eq(\"exact_score\") & distances_df[\"num_steps\"].eq(num_steps), 'w2']\n",
    "true_score_W2_EI = distances_df.loc[distances_df[\"scheme\"].eq('semii') & distances_df[\"loss\"].eq(\"exact_score\") & distances_df[\"num_steps\"].eq(num_steps), 'w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3dc0028-8b8b-4ac3-89e0-c112a12f9042",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the W2 upper bound as a function of the schedule parameter a.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(a_values, W2_upperbound)\n",
    "ax.set_xlabel(\"Values of a\")\n",
    "ax.set_ylabel(\"Upper bound (W2)\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cfe9c70-dc1c-4290-90ae-0f5b256e0c89",
   "metadata": {},
   "outputs": [],
   "source": [
    "######## SELECT TRAINING PARAMETERS W.R.T THE DIMENSION ######################\n",
    "# optimisation parameters (n_epochs, learning_rate) for each supported dimension d\n",
    "TRAINING_PARAMS_BY_DIM = {\n",
    "    5: (2, 1.0e-4),\n",
    "    10: (30, 1.0e-4),\n",
    "    25: (50, 1.0e-3),\n",
    "    50: (150, 1.0e-3),\n",
    "}\n",
    "if d not in TRAINING_PARAMS_BY_DIM:\n",
    "    # fail loudly instead of silently reusing n_epochs / learning_rate left over from an earlier cell\n",
    "    raise ValueError(f\"no training parameters for d = {d}; supported: {sorted(TRAINING_PARAMS_BY_DIM)}\")\n",
    "n_epochs, learning_rate = TRAINING_PARAMS_BY_DIM[d]\n",
    "\n",
    "# network_parameters\n",
    "mid_features = 256\n",
    "num_layers = 3\n",
    "batch_size = 64\n",
    "\n",
    "# Monte Carlo estimate samples (for E2 computation in the KL bound)\n",
    "num_mc = 500\n",
    "\n",
    "# number of runs (for replication)\n",
    "rep_size = 1\n",
    "\n",
    "# load data\n",
    "dataloader = DataLoader(training_sample, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# number of discretisation steps, read as `num_steps` by the bound, E2 and sampling cells\n",
    "num_steps = 150\n",
    "num_step = num_steps  # legacy alias: this value used to be stored only under `num_step`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b80434-478c-4e61-85e3-6a3583f138c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "###\n",
    "## TRAINING NN FOR DIFFERENT NOISE SCHEDULES\n",
    "###\n",
    "# LOAD_PRETRAINED = True loads checkpoints from MODEL_DIR instead of training;\n",
    "# SAVE_MODELS = True writes the freshly trained networks there.\n",
    "LOAD_PRETRAINED = False\n",
    "SAVE_MODELS = False\n",
    "# NOTE(review): the old save path used 'd50_explicit_rescale_standardize_ep150' while the old load\n",
    "# path used 'd50_explicit_rescale_ep150'; one directory now serves both - confirm which one holds\n",
    "# the intended checkpoints.\n",
    "MODEL_DIR = f'models/gaussian_covar/d{d}_explicit_rescale_standardize_ep{n_epochs}'\n",
    "\n",
    "score_theta_trained = [ [] for _ in a_values ]   # a_values, replicates\n",
    "for k, a in enumerate(a_values):\n",
    "    if not LOAD_PRETRAINED:\n",
    "        print(f\"optimization for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    for j in range(rep_size):\n",
    "        network = decoder.Decoder(sde.d, mid_features, num_layers).to(device)\n",
    "        checkpoint_path = os.path.join(MODEL_DIR, f'model_{k}_{j}_d{d}.pt')\n",
    "        if LOAD_PRETRAINED:\n",
    "            # map_location lets GPU-saved checkpoints load on a CPU-only machine\n",
    "            network.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "        else:\n",
    "            optimizer = Adam(network.parameters(), lr=learning_rate)\n",
    "            #loss = diff.loss_conditional(network, sde) # to choose when the score function is not analytically available\n",
    "            loss = diff.loss_explicit(network, sde, diff.explicit_score(sde, training_distribution)) # to choose in the Gaussian case for training with the analytical score function\n",
    "            diff.train(loss, dataloader, n_epochs=n_epochs, optimizer=optimizer)\n",
    "            if SAVE_MODELS:\n",
    "                os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "                torch.save(network.state_dict(), checkpoint_path)\n",
    "        score_theta_trained[k].append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2c880da-4df5-4e81-8be2-b36a624e4ad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "### GAUSSIAN CASE\n",
    "# For every a: score approximation error (func.compute_E2) of each trained network, and the W2\n",
    "# distance between its generated samples and training_distribution.\n",
    "# NOTE(review): `dist`, `dist_emp`, `init` and `num_steps` are not defined in this section and\n",
    "# come from earlier cells; training used diff.explicit_score(sde, training_distribution) while\n",
    "# exp_score below uses dist_emp - confirm both refer to the correlated data above.\n",
    "\n",
    "simulation_results = []\n",
    "\n",
    "for k, a in tqdm(enumerate(a_values), total=len(a_values)):\n",
    "    tqdm.write(f\"simulation and M for a = {a}\")\n",
    "    sde.beta.change_a(a)\n",
    "    exp_score = diff.explicit_score(sde, dist_emp)\n",
    "\n",
    "    for j, score_theta in enumerate(score_theta_trained[k]):\n",
    "        _, error_approx_sup_L2 = func.compute_E2(dist, sde, score_theta, exp_score, num_steps, num_mc)\n",
    "\n",
    "        # append (\"semii\", sp.EI_discr_sampler) to the list to also evaluate that scheme\n",
    "        for scheme_name, scheme_sampler in [(\"euler\", sp.Euler_Maruyama_discr_sampler)]:\n",
    "            sample = scheme_sampler(init, sde, score_theta, num_steps)\n",
    "            #kl_value = kl(training_distribution, empirical(func.unnormalize(sample,func.normalize(training_sample, rescale)[1], func.normalize(training_sample, rescale)[2],rescale)))\n",
    "            #w2_value = w2(training_distribution, empirical(func.unnormalize(sample, a, b ,rescale)))\n",
    "            w2_value = w2(training_distribution, empirical(sample))\n",
    "\n",
    "            simulation_results.append({\n",
    "                \"a\": a,\n",
    "                \"replication\": j,\n",
    "                \"scheme\": scheme_name,\n",
    "                \"error_approx_sup_L2\": error_approx_sup_L2,\n",
    "                #\"kl\": kl_value,\n",
    "                \"w2\": w2_value,\n",
    "            })\n",
    "\n",
    "# one row per (a, replication, scheme)\n",
    "simulation_df = pd.DataFrame(simulation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54069d27-e230-4d9d-bb62-49e1fdc80a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-a summaries used by the W2 plot (run after the simulation cell above).\n",
    "results_df = simulation_df.merge(errors_df, on=\"a\", how=\"left\")\n",
    "\n",
    "# W2 upper bound per a (errors_df carries one value per a)\n",
    "W2_upperbound = results_df.groupby('a')['w2_error'].mean() #* 1/rescale\n",
    "\n",
    "# empirical W2 of the generated samples, per discretisation scheme and per a\n",
    "mean_all_Euler_W2 = results_df.loc[results_df['scheme'].eq('euler')].groupby('a')['w2'].mean()\n",
    "std_all_Euler_W2 = results_df.loc[results_df['scheme'].eq('euler')].groupby('a')['w2'].std()\n",
    "mean_all_EI_W2 = results_df.loc[results_df['scheme'].eq('semii')].groupby('a')['w2'].mean()\n",
    "std_all_EI_W2 = results_df.loc[results_df['scheme'].eq('semii')].groupby('a')['w2'].std()\n",
    "\n",
    "# the a = 0 Euler run is the VPSDE reference\n",
    "VPSDE_mean_W2 = mean_all_Euler_W2.loc[0.0]\n",
    "VPSDE_std_W2 = std_all_Euler_W2.loc[0.0]\n",
    "\n",
    "# exact-score W2 for the current num_steps (distances_df is defined outside this section)\n",
    "true_score_W2_Euler = distances_df.loc[distances_df[\"scheme\"].eq('euler') & distances_df[\"loss\"].eq(\"exact_score\") & distances_df[\"num_steps\"].eq(num_steps), 'w2']\n",
    "true_score_W2_EI = distances_df.loc[distances_df[\"scheme\"].eq('semii') & distances_df[\"loss\"].eq(\"exact_score\") & distances_df[\"num_steps\"].eq(num_steps), 'w2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c5bcf49-6eee-4a44-a01c-84f030db27ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### PLOT WASSERSTEIN ########\n",
    "font = {'family' : 'sans-serif',\n",
    "        'size'   : 16}\n",
    "plt.rc('font', **font)\n",
    "# plt.get_cmap works across Matplotlib versions (matplotlib.cm.get_cmap was removed in 3.9)\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "mpl.rcParams['axes.facecolor'] = 'white'\n",
    "mpl.rcParams['axes.grid'] = False\n",
    "mpl.rcParams['axes.edgecolor'] = 'gray'\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "# Left axis: W2 upper bound, shaded by +/- one std (over the runs for each a) of the\n",
    "# measured score approximation error error_approx_sup_L2.\n",
    "y1_color = cmap(1/20)\n",
    "bound_band = simulation_df.groupby('a')['error_approx_sup_L2'].std().to_numpy()\n",
    "ax1.plot(a_values, W2_upperbound, color=y1_color, label=\"Upper-bound\",linewidth=3)\n",
    "ax1.fill_between(a_values,\n",
    "                 W2_upperbound.to_numpy() - bound_band,\n",
    "                 W2_upperbound.to_numpy() + bound_band,\n",
    "                 fc=y1_color,alpha=.3)\n",
    "ax1.set_xlabel(\"Values of a\")\n",
    "ax1.set_ylabel(\"Upper bound (W2)\", color=y1_color)\n",
    "ax1.tick_params(axis='y', labelcolor=y1_color)\n",
    "\n",
    "# Right axis: empirical W2 of the generated samples (trained network and exact score),\n",
    "# plus the a = 0 (VPSDE) run as a horizontal reference.\n",
    "y2_color = cmap(5/20)\n",
    "y3_color = cmap(5/20)\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax2.plot(a_values, mean_all_Euler_W2, color=y2_color, label= r\"$\\mathcal{W}_2(\\pi_{\\rm data},\\hat{\\pi})$ (NN) \",linewidth=3)\n",
    "ax2.plot(a_values, true_score_W2_Euler, color=y2_color, label = r\"$\\mathcal{W}_2(\\pi_{\\rm data}$,$\\hat{\\pi}$) (exact score)\",linewidth=3, linestyle='dotted')\n",
    "ax2.fill_between(a_values,\n",
    "                 mean_all_Euler_W2-std_all_Euler_W2,\n",
    "                 mean_all_Euler_W2+std_all_Euler_W2,\n",
    "                 fc=y2_color,alpha=.3)\n",
    "ax2.plot(a_values, np.full_like(a_values, VPSDE_mean_W2), color=y3_color, label=\"VPSDE (NN)\",linewidth=2, linestyle='--')\n",
    "ax2.fill_between(a_values,\n",
    "                 VPSDE_mean_W2-VPSDE_std_W2,\n",
    "                 VPSDE_mean_W2+VPSDE_std_W2,\n",
    "                 fc=y3_color,alpha=.2)\n",
    "ax2.set_ylabel(r\"$\\mathcal{W}_2$ distance\", color=y2_color)\n",
    "ax2.tick_params(axis='y', labelcolor=y2_color)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# plt.legend() would only list the lines of the current (twin) axis; merge the handles of\n",
    "# both axes so the upper-bound curve appears in the legend too.\n",
    "handles_1, labels_1 = ax1.get_legend_handles_labels()\n",
    "handles_2, labels_2 = ax2.get_legend_handles_labels()\n",
    "ax2.legend(handles_1 + handles_2, labels_1 + labels_2, loc = 'center right')\n",
    "#plt.savefig(\"divlambdamax_COV.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8e838e5-f94c-41a9-8acf-1770063cf73d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# W2 upper bound per a, plotted against its own index (the a values present in results_df).\n",
    "W2_upperbound = results_df.groupby('a')['w2_error'].mean() #* 1/rescale\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(W2_upperbound.index, W2_upperbound.to_numpy())\n",
    "ax.set_xlabel(\"Values of a\")\n",
    "ax.set_ylabel(\"Upper bound (W2)\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
