{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b370c313-6db4-4ee9-8711-4999204acc8a",
   "metadata": {},
   "source": [
    "## Nonlinear Constraint Projection for PME (Table 3)\n",
    "\n",
    "This notebook implements nonlinear constraint projection for solving the **Porous Medium Equation (PME)** with parameterized nonlinearity $m$. It is used to generate the results in **Table 3** of the main paper by enforcing hard nonlinear constraints—such as **mass conservation**—on the predicted solution.\n",
    "\n",
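    "The projection ensures that the predicted mean $\\mu$ satisfies a nonlinear constraint of the form:\n",
    "\n",
    "$$\n",
    "\\frac{d}{dt} \\int_{\\Omega} u(t,x) \\, d\\Omega = F(u(t,x_0)) - F(u(t, x_N)),\n",
    "$$\n",
    "\n",
    "i.e. the rate of change of the total mass in $\\Omega$ equals the net flux $F$ through the boundary points $x_0$ and $x_N$.\n",
    "\n",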
    "---\n",
    "\n",
    "### Switching Between Projection Modes\n",
    "\n",
    "Two types of projection are supported in this notebook:\n",
    "\n",
    "- **Oblique projection** (default): the projection is computed in the inner product weighted by the inverse variance $Q = \\Sigma^{-1}$.\n",
    "- **Orthogonal projection**: the projection is computed with respect to the standard Euclidean norm.\n",
    "\n",
    "To switch between the two, call the corresponding projection function:\n",
    "\n",
    "```python\n",
    "# For oblique projection (default):\n",
    "project_and_stats(...)\n",
    "\n",
    "# For orthogonal projection:\n",
    "project_and_stats_orth(...)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "6b3ccdb8-4fd8-4224-b23e-979fa8228059",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from models.FNO2d import FNO2d\n",
    "from models.DiverseFNO2d import DiverseFNO2d\n",
    "from models.UncertainNO import *\n",
    "import utils\n",
    "from einops import rearrange, reduce, repeat\n",
    "import os\n",
    "from docopt import docopt\n",
    "import dill\n",
    "from datasets import *\n",
    "import probconserv\n",
    "import sys\n",
    "import torch.optim as optim\n",
    "\n",
    "# args = docopt(__doc__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "id": "3f650c83-1845-4ef1-9b8b-2aaf18d83d08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.6.0+cu118'"
      ]
     },
     "execution_count": 161,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "id": "dabd741a-54b6-4556-8585-7ed8dfaff61c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Tesla V100S-PCIE-32GB'"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get_device_name() raises when no CUDA device is available, so fall back to a CPU label.\n",
    "torch.cuda.get_device_name() if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "id": "5fdcc8cd-82d9-4cec-888d-806541259f2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# %reload_ext (unlike %load_ext) is safe to re-run: it loads the extension on first use and reloads it afterwards.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "id": "1d5eb530-ab43-4054-8953-9efdb86b3106",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration in docopt's key format; numeric values are strings, parsed with int()/float() downstream.\n",
    "args = {\n",
    "    '--batch_size': '20',\n",
    "    '--dataset': 'PME_1D',\n",
    "    '--dataset_params': '4,5',\n",
    "    '--epochs': '200',\n",
    "    '--fno_modes': '12',\n",
    "    '--fno_width': '32',\n",
    "    '--grid_len': '100',\n",
    "    '--lr': '1e-3',\n",
    "    '--m.drop_prob': '0.1',\n",
    "    '--m.n_models': '10',\n",
    "    '--m.n_regularize': '5',\n",
    "    '--m.reg_strength': '1',\n",
    "    '--m.reg_type': 'weights_l2',\n",
    "    '--model': 'OutputVarFNO2d',\n",
    "    '--n_samples': '200',\n",
    "    '--no_train': False,\n",
    "    '--ood_dataset_params': None,\n",
    "    '--predict_time': '0,-1,5',\n",
    "    '--seed': '0',\n",
    "    '--time_len': '100',\n",
    "    '--tplot': '0.5',\n",
    "    '--train_ood_dataset_params': '4,5',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "9ac91626-a64a-4ae2-ae5f-f63f492a7d4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "experiment_name = \"trial\"\n",
    "# save_args: args minus run-time-only flags (presumably the config stored with the run; confirm in utils).\n",
    "save_args = utils.filter_config(args, [\"generate\", \"--no_train\", \"--ood_dataset_params\", \"--tplot\"], mode=\"remove\")  # Also removes \".\" keys\n",
    "\n",
    "is_train = not bool(args[\"--no_train\"])\n",
    "\n",
    "# Parameters\n",
    "n_x = int(args[\"--grid_len\"])\n",
    "n_t = int(args[\"--time_len\"])\n",
    "n_samples = int(args[\"--n_samples\"])\n",
    "n_train = int(0.8 * n_samples)\n",
    "n_valid = int(0.2 * n_samples)\n",
    "n_test = n_samples // 2\n",
    "\n",
    "is_markov = False\n",
    "\n",
    "def parse_params(spec):\n",
    "    \"\"\"Parse a comma-separated string such as \"4,5\" into a list of floats.\"\"\"\n",
    "    return [float(val) for val in spec.split(\",\")]\n",
    "\n",
    "dataset = args[\"--dataset\"]\n",
    "dataset_params = parse_params(args[\"--dataset_params\"])\n",
    "train_ood_dataset_params = parse_params(args[\"--train_ood_dataset_params\"])\n",
    "# OOD evaluation reuses the training-time OOD parameters unless an evaluation-only run (--no_train)\n",
    "# supplies --ood_dataset_params; a missing (None) value previously crashed on None.split().\n",
    "ood_dataset_params = train_ood_dataset_params\n",
    "if not is_train and args[\"--ood_dataset_params\"] is not None:\n",
    "    ood_dataset_params = parse_params(args[\"--ood_dataset_params\"])\n",
    "\n",
    "tpred = [int(val) for val in args[\"--predict_time\"].split(\",\")]\n",
    "\n",
    "fno_modes = int(args[\"--fno_modes\"])\n",
    "fno_width = int(args[\"--fno_width\"])\n",
    "\n",
    "batch_size = int(args[\"--batch_size\"])\n",
    "lr = float(args[\"--lr\"])\n",
    "epochs = int(args[\"--epochs\"])\n",
    "step_size = 50  # NOTE(review): presumably LR-scheduler step/decay settings for training - confirm\n",
    "gamma = 0.5\n",
    "\n",
    "# Set seed\n",
    "utils.set_seed(int(args[\"--seed\"]))\n",
    "\n",
    "# Select the PDE dataset: every dataset uses t in [0, 1]; only the heat equation uses x in [0, 2*pi].\n",
    "if dataset.lower() == \"HeatEquation_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 2 * np.pi, n_x)\n",
    "    dataset_class = HeatEquation_1D\n",
    "elif dataset.lower() == \"PME_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 1, n_x)\n",
    "    dataset_class = PME_1D\n",
    "elif dataset.lower() == \"StefanPME_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 1, n_x)\n",
    "    dataset_class = StefanPME_1D\n",
    "elif dataset.lower() == \"LinearAdvection_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 1, n_x)\n",
    "    dataset_class = LinearAdvection_1D\n",
    "else:\n",
    "    raise NotImplementedError(f\"Unsupported dataset: {dataset}\")\n",
    "\n",
    "t_sliced = t[slice(*tpred)]\n",
    "T = len(t_sliced)\n",
    "\n",
    "def get_xy_from_pu(p, u, is_markov=False):\n",
    "    \"\"\"Build (input, target) tensors from PDE parameters `p` and solutions `u`.\n",
    "\n",
    "    Following the einops patterns below, `p` is (nf, nx, 1) and `u` is (nf, nx, nt, 1).\n",
    "\n",
    "    - is_markov=False: the target is the full trajectory `u`; the input is `p` repeated\n",
    "      along time, shape (nf, nx, nt, 1).\n",
    "    - is_markov=True: one-step pairs. The input concatenates `p` with the state at step k\n",
    "      on the last axis, shape (nf*(nt-1), nx, 2); the target is the state at step k+1,\n",
    "      shape (nf*(nt-1), nx, 1).\n",
    "    \"\"\"\n",
    "    T = u.shape[2]\n",
    "    if is_markov:\n",
    "        x0, y0 = p, u\n",
    "        \n",
    "        y0_vectorized = rearrange(y0[:, :, 0:T-1], \"nf nx nt 1 -> (nf nt) nx 1\")\n",
    "        x0 = repeat(x0, \"nf nx 1 -> (nf nt) nx 1\", nt=T-1)\n",
    "        x = torch.cat([x0, y0_vectorized], dim=-1)\n",
    "        \n",
    "        y = rearrange(y0[:, :, 1:T], \"nf nx nt 1 -> (nf nt) nx 1\")\n",
    "    else:\n",
    "        x, y = p, u\n",
    "        x = repeat(x, \"nf nx 1 -> nf nx T 1\", T=T)\n",
    "    return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "b02f6033-3977-4d51-8f39-024a376dbabf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here 160\n",
      "torch.Size([160, 100, 1]) torch.Size([160, 100, 20, 1])\n",
      "torch.Size([160, 100, 20, 1]) torch.Size([160, 100, 20, 1])\n"
     ]
    }
   ],
   "source": [
    "def make_loader(x, y, shuffle=False):\n",
    "    \"\"\"Wrap paired (input, target) tensors in a DataLoader with the configured batch size.\"\"\"\n",
    "    return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x, y),\n",
    "                                       batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "if is_train:\n",
    "    # Train data\n",
    "    a, u, p = dataset_class.generate_dataset(n_train, grid, t, tpred, *dataset_params)\n",
    "    x_train, y_train = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # Validation data\n",
    "    a, u, p = dataset_class.generate_dataset(n_valid, grid, t, tpred, *dataset_params)\n",
    "    x_valid, y_valid = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # In-distribution test data\n",
    "    a, u, p = dataset_class.generate_dataset(n_test, grid, t, tpred, *dataset_params)\n",
    "    x_id_test, y_id_test = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # Out-of-distribution test data (training-time OOD parameters)\n",
    "    a, u, p = dataset_class.generate_dataset(n_test, grid, t, tpred, *train_ood_dataset_params)\n",
    "    x_ood_test, y_ood_test = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # Data loaders: only the training split is shuffled.\n",
    "    train_loader = make_loader(x_train, y_train, shuffle=True)\n",
    "    valid_loader = make_loader(x_valid, y_valid)\n",
    "    id_test_loader = make_loader(x_id_test, y_id_test)\n",
    "    ood_test_loader = make_loader(x_ood_test, y_ood_test)\n",
    "\n",
    "    # x_train / y_train exist only on the training path, so report their shapes here.\n",
    "    print(x_train.shape, y_train.shape)\n",
    "else:\n",
    "    # OOD test data\n",
    "    a, u, p = dataset_class.generate_dataset(n_test, grid, t, tpred, *ood_dataset_params)\n",
    "    x_ood_test, y_ood_test = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "    ood_test_loader = make_loader(x_ood_test, y_ood_test)\n",
    "    print(x_ood_test.shape, y_ood_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "b64f1eeb-0e6e-4167-a082-5479ac453cb9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "id": "2bbd9b90-2b82-439d-b7e1-a642fdd2d348",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time axis, prediction-time slice, spatial grid and dataset class, handed to the model as `constraint_context`.\n",
    "constraint_context = dict(\n",
    "    t=t.to(device),\n",
    "    tpred=torch.tensor(tpred, device=device),\n",
    "    grid_train=grid.to(device),\n",
    "    dataset_class=dataset_class,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e5c61c-2b59-466f-9b4a-6157fa29ad4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "uq = False\n",
    "model_name = args[\"--model\"]\n",
    "model_key = model_name.lower()\n",
    "n_models = 1\n",
    "fno_modes2 = min(fno_modes, 12)  # the second Fourier-mode count is capped at 12\n",
    "# Hyperparameters shared by every FNO2d-based model below.\n",
    "base_fno_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width}\n",
    "FNO2d_params = dict(base_fno_params)\n",
    "\n",
    "# NOTE(review): only the FNO2d and DiverseFNO2d branches call .to(device); confirm that the\n",
    "# EnsembleNO / BayesianNO / MCDropoutNO / ProbHardE2E wrappers handle device placement themselves.\n",
    "if model_key == \"FNO2d\".lower():\n",
    "    FNO2d_params = {**base_fno_params, \"output_var\": True}\n",
    "    model = FNO2d(**FNO2d_params).to(device)\n",
    "elif model_key.startswith(\"EnsembleFNO2d\".lower()):\n",
    "    n_models = int(args[\"--m.n_models\"])\n",
    "    utils.filter_config(args, [\"--m.n_models\"], mode=\"add\", new_config=save_args)\n",
    "    model = EnsembleNO(base_model_class=FNO2d, base_model_params=FNO2d_params, n_models=n_models)\n",
    "    uq = True\n",
    "elif model_key.startswith(\"BayesianFNO2d\".lower()):\n",
    "    model = BayesianNO(base_model_class=FNO2d, base_model_params=FNO2d_params)\n",
    "    uq = True\n",
    "elif model_key.startswith(\"MCDropoutFNO2d\".lower()):\n",
    "    dropout = float(args[\"--m.drop_prob\"])\n",
    "    n_dropouts = int(args[\"--m.n_models\"])\n",
    "    utils.filter_config(args, [\"--m.n_models\", \"--m.drop_prob\"], mode=\"add\", new_config=save_args)\n",
    "    model = MCDropoutNO(base_model_class=FNO2d, base_model_params=FNO2d_params, dropout=dropout, n_dropouts=n_dropouts)\n",
    "    uq = True\n",
    "elif model_key.startswith(\"OutputVarFNO2d\".lower()):\n",
    "    # ProbHardE2E (fed the constraint context) replaces the earlier OutputVarNO wrapper for this model name.\n",
    "    model = ProbHardE2E(base_model_class=FNO2d, probconserv=False, base_model_params=FNO2d_params, constraint_context=constraint_context)\n",
    "    uq = True\n",
    "elif model_key.startswith(\"DiverseFNO2d\".lower()):\n",
    "    lam = float(args[\"--m.reg_strength\"])\n",
    "    reg_type = args[\"--m.reg_type\"]\n",
    "    n_models = int(args[\"--m.n_models\"])\n",
    "    n_regularize = int(args[\"--m.n_regularize\"])\n",
    "    utils.filter_config(args, [\"--m.n_models\", \"--m.reg_strength\", \"--m.reg_type\", \"--m.n_regularize\"], mode=\"add\", new_config=save_args)\n",
    "    model = DiverseFNO2d(reg_loss=reg_type, n_outputs=n_models, bias_last=False, lam=lam, n_regularize=n_regularize, **FNO2d_params).to(device)\n",
    "    uq = True\n",
    "else:\n",
    "    raise NotImplementedError(f\"Unsupported model: {model_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "id": "4333d6f8-fe2c-4caa-a889-efae589c31ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([160, 100, 20, 1])"
      ]
     },
     "execution_count": 170,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "id": "96aa9ee0-a30e-458b-85e2-5e7dbbd89dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical mean of the training targets across samples (dim 0).\n",
    "mu_true = y_train.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "id": "fc1ac32d-6fb2-4293-8ea7-f90486a32dce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical variance across samples (dim 0); torch's default applies Bessel's correction.\n",
    "var_true = y_train.var(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "f07c25ce-49e4-4331-81a3-ee678808a6b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 20, 1])"
      ]
     },
     "execution_count": 173,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mu_true.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "id": "0f4d18b3-1988-4417-9122-3a8709048fc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0260, 0.0270, 0.0280, 0.0290, 0.0299, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOa0lEQVR4nO3deVxU9f4/8NfsA7K4gIDIFZfcTQyC0MxKitJrWldDLUVKW9RK+dovaRG1EivzUl2Uq7ld06TUrKtmV0lumtwwl9JcSkUljc0FCJSBmc/vD53RkQHmDLMBr+fjMY+HnDlnznsO4Lz4LOcjE0IIEBERETURclcXQERERGRPDDdERETUpDDcEBERUZPCcENERERNCsMNERERNSkMN0RERNSkMNwQERFRk8JwQ0RERE0Kww0RERE1KQw35FRZWVmQyWTIyspyyvlmz54NmUzmlHOR/dx777249957XV2GGZlMhtmzZ7u6jAYrKCjAyJEj0aZNG8hkMqSmprq6JCK7Y7ghu1i5ciVkMpnpodVq0bVrV0ydOhUFBQV2OcfWrVstfrhUVFRg9uzZTgtMRI3Z9OnT8c033yApKQmrV6/GQw895NDz3fz/glwuR7t27fDggw/W+H0NDQ2FTCZDTEyMxddZunSp6XV+/PFH03bjHzC1PfLz862uNScnB5MnT0Z4eDhUKlWtfxidPn3a7BwqlQp+fn7o378/Xn31VZw9e9bqc5JjKF1dADUtc+fORceOHXH16lXs3r0bixcvxtatW3H48GF4eno26LW3bt2KtLS0GgGnoqICc+bMAYAaf+2//vrrmDlzZoPOSwQAV65cgVLZ+P/L/PbbbzF8+HDMmDHDaed84IEHMH78eAghkJubi0WLFuH+++/Hli1b8PDDD5v202q12LlzJ/Lz8xEYGGj2GmvWrIFWq8XVq1ctnmPx4sXw8vKqsb1ly5ZW17l161Z8/PHHuP3229GpUyf8+uuvde4/ZswYDBkyBAaDAZcuXcLevXuRmpqKDz74AMuWLcPo0aOtPjfZV+P/TSW38vDDDyMiIgIAMHHiRLRp0wYLFy7El19+iTFjxji9HqVS2SQ+kMg1DAYDdDodtFottFqtq8uxi8LCQkkf+PW5evUq1Go15PLaOwK6du2KJ5980vT1o48+ittvvx2pqalm4WbAgAHYu3cvMjIy8NJLL5m2//7779i1axceffRRbNiwweI5Ro4cCT8/vwa9l+effx6vvPIKPDw8MHXq1HrDzR133GH2vgDgzJkzePDBBxEfH48ePXqgb9++DaqJbMNuKXKo+++/HwCQm5tb536ff/45wsPD4eHhAT8/Pzz55JM4d+6c6fkJEyYgLS0NgHkz9+nTp+Hv7w8AmDNnjmm7sXXH0pgbmUyGqVOnYtOmTejduzc0Gg169eqFbdu21agrKysLERER0Gq16Ny5M/75z39aPY7n3nvvRe/evfHzzz9j0KBB8PT0RJcuXbB+/XoAwH//+19ERUXBw8MD3bp1w44dO2q8xrlz5/DUU08hICDAVOfy5cvN9tHpdJg1axbCw8Ph6+uLFi1aYODAgdi5c6fZfsam9AULFmDJkiXo3LkzNBoN7rzzTuzdu7fe93Px4kXMmDEDffr0gZeXF3x8fPDwww/jp59+qnHNZDIZPvvsM7z99tto3749tFotBg8ejBMnTtR4XWMtHh4eiIyMxK5du+qtBQB69+6N++67r8Z2g8GA4OBgjBw50rRtwYIF6N+/P9q0aQMPDw+Eh4ebvg83M/5srFmzBr169YJGozH9XNw65ubMmTOYPHkyunXrBg8PD7Rp0wajRo3C6dOnzV7T2GX7/fffIzExEf7+/mjRogUeffRRFBUV1ajh66+/xqBBg+Dt7Q0fHx/ceeedWLt2rdk+P/zwAx566CH4+vrC09MTgwYNwvfff1/n9TLWIYRAWlqa6XfF6NSpUxg1ahRat24NT09P3HXX
XdiyZYvZaxi/t+vWrcPrr7+O4OBgeHp6orS0tM5z36pPnz7w8/Or8f+CVqvFY489VuP9fvrpp2jVqhViY2MlnUeqgIAAeHh4NOg1OnTogJUrV0Kn0+Hdd9+1U2UkFf+kJYc6efIkAKBNmza17rNy5UokJCTgzjvvREpKCgoKCvDBBx/g+++/x4EDB9CyZUs8++yzOH/+PLZv347Vq1ebjvX398fixYvx/PPP49FHH8Vjjz0GALj99tvrrGv37t3YuHEjJk+eDG9vb3z44Yf429/+hrNnz5pqPXDgAB566CEEBQVhzpw50Ov1mDt3rilMWePSpUv461//itGjR2PUqFFYvHgxRo8ejTVr1mDatGl47rnnMHbsWLz33nsYOXIk8vLy4O3tDeDawM+77rrL9IHr7++Pr7/+Gk8//TRKS0sxbdo0AEBpaSk+/vhjjBkzBpMmTUJZWRmWLVuG2NhY5OTkICwszKymtWvXoqysDM8++yxkMhneffddPPbYYzh16hRUKlWt7+XUqVPYtGkTRo0ahY4dO6KgoAD//Oc/MWjQIBw5cgTt2rUz23/+/PmQy+WYMWMGSkpK8O677+KJJ57ADz/8YNpn2bJlePbZZ9G/f39MmzYNp06dwiOPPILWrVsjJCSkzmsbFxeH2bNn1+jC2L17N86fP2/WJfDBBx/gkUcewRNPPAGdTod169Zh1KhR2Lx5M4YOHWr2ut9++y0+++wzTJ06FX5+fggNDbV4/r1792LPnj0YPXo02rdvj9OnT2Px4sW49957ceTIkRrdsC+88AJatWqF5ORknD59GqmpqZg6dSoyMjJM+6xcuRJPPfUUevXqhaSkJLRs2RIHDhzAtm3bMHbsWFN9Dz/8MMLDw5GcnAy5XI4VK1bg/vvvx65duxAZGWmx3nvuuQerV6/GuHHjTN1ERgUFBejfvz8qKirw4osvok2bNli1ahUeeeQRrF+/Ho8++qjZa7355ptQq9WYMWMGKisroVar6/hO1XTp0iVcunQJXbp0qfHc2LFj8eCDD+LkyZPo3LkzgGs/syNHjqzz5/PixYs1timVSru2UlkrOjoanTt3xvbt251+brpOENnBihUrBACxY8cOUVRUJPLy8sS6detEmzZthIeHh/j999+FEELs3LlTABA7d+4UQgih0+lE27ZtRe/evcWVK1dMr7d582YBQMyaNcu0bcqUKcLSj2xRUZEAIJKTk2s8l5ycXOMYAEKtVosTJ06Ytv30008CgPjoo49M24YNGyY8PT3FuXPnTNt+++03oVQqLdZxq0GDBgkAYu3ataZtx44dEwCEXC4X//vf/0zbv/nmGwFArFixwrTt6aefFkFBQaK4uNjsdUePHi18fX1FRUWFEEKI6upqUVlZabbPpUuXREBAgHjqqadM23JzcwUA0aZNG3Hx4kXT9i+//FIAEP/+97/rfD9Xr14Ver3ebFtubq7QaDRi7ty5pm3G73GPHj3M6vrggw8EAHHo0CEhxI3vfVhYmNl+S5YsEQDEoEGD6qzn+PHjNb5nQggxefJk4eXlZbo+QgizfxvP3bt3b3H//febbTd+b3755Zca57v1Z+zW1xRCiOzsbAFA/Otf/zJtM/5uxMTECIPBYNo+ffp0oVAoxOXLl4UQQly+fFl4e3uLqKgos98FIYTpOIPBIG677TYRGxtr9loVFRWiY8eO4oEHHqhRk6X3MWXKFLNt06ZNEwDErl27TNvKyspEx44dRWhoqOn7bvzedurUyeL7r+18Tz/9tCgqKhKFhYXihx9+EIMHDxYAxPvvv2/ar0OHDmLo0KGiurpaBAYGijfffFMIIcSRI0cEAPHf//7XdC337t1rOs74O27p0a1bN6tqtKS2/2+EuPG79N5779V6/PDhwwUAUVJSYnMNZDt2S5FdxcTEwN/fHyEhIRg9ejS8vLzwxRdfIDg42OL+P/74IwoLCzF58mSzMQ1Dhw5F9+7dazSL27NO41+FwLWWHh8f
H5w6dQoAoNfrsWPHDowYMcKsRaJLly5mYwTq4+XlZdaC0K1bN7Rs2RI9evRAVFSUabvx38bzCyGwYcMGDBs2DEIIFBcXmx6xsbEoKSnB/v37AQAKhcL0l7PBYMDFixdRXV2NiIgI0z43i4uLQ6tWrUxfDxw40OzctdFoNKZxFXq9HhcuXICXlxe6detm8TwJCQlmf9Hfeh7j9/65554z22/ChAnw9fWtsxbg2jiOsLAws5YPvV6P9evXY9iwYWbdCzf/+9KlSygpKcHAgQMt1j1o0CD07Nmz3vPf/JpVVVW4cOECunTpgpYtW1p83WeeecasG2jgwIHQ6/U4c+YMAGD79u0oKyvDzJkza4zvMR538OBB/Pbbbxg7diwuXLhg+pkoLy/H4MGD8d1338FgMNRb+622bt2KyMhI3H333aZtXl5eeOaZZ3D69GkcOXLEbP/4+HhJ3TfLli2Dv78/2rZti6ioKFMXnbH18WYKhQKPP/44Pv30UwDXBhKHhISYfn5qs2HDBmzfvt3ssWLFCqtrtDfj4OaysjKX1dCcsVuK7CotLQ1du3aFUqlEQEAAunXrVudAQ+N/7N26davxXPfu3bF7926H1PmXv/ylxrZWrVrh0qVLAK4Nurxy5YrFZnNL22rTvn37GuNzfH19a3S5GD/MjecvKirC5cuXsWTJEixZssTiaxcWFpr+vWrVKrz//vs4duwYqqqqTNs7duxY47hb37sx6BjPXRuDwYAPPvgAixYtQm5uLvR6vek5S92O9Z3H+L2/7bbbzPZTqVTo1KlTnbUYxcXF4dVXX8W5c+cQHByMrKwsFBYWIi4uzmy/zZs346233sLBgwdRWVlp2m5p7JSla2bJlStXkJKSghUrVuDcuXMQQpieKykpqbF/fdfD2IXbu3fvWs/522+/AbgWLmpTUlJiFl6tcebMGbOwbdSjRw/T8zfXZe01Mho+fDimTp0KmUwGb29v9OrVCy1atKh1/7Fjx+LDDz/ETz/9hLVr12L06NH1jnO75557Gjyg2J7+/PNPADB1M5NzMdyQXUVGRppmS7kzhUJhcfvNH1COPE995zf+9f3kk0/W+kFmHFf0ySefYMKECRgxYgRefvlltG3bFgqFAikpKaYPTCnnrs28efPwxhtv4KmnnsKbb76J1q1bQy6XY9q0aRZbC5xxjePi4pCUlITPP/8c06ZNw2effQZfX1+ze7fs2rULjzzyCO655x4sWrQIQUFBUKlUWLFiRY2BqwCsbpF44YUXsGLFCkybNg3R0dHw9fWFTCbD6NGjHXY9jK/73nvv1RhLZWRpOrS9SR102759+1rvX2NJVFQUOnfujGnTpiE3N9c03qgxOXz4MNq2bQsfHx9Xl9IsMdyQS3Xo0AEAcPz4cdPMKqPjx4+bngcs/5Vd1/aGaNu2LbRarcXZPZa22Zu/vz+8vb2h1+vr/VBYv349OnXqhI0bN5pdi+TkZLvWtH79etx3331YtmyZ2fbLly/b9Bez8Xv722+/mX3vq6qqkJuba9UU2o4dOyIyMhIZGRmYOnUqNm7ciBEjRkCj0Zj22bBhA7RaLb755huz7Q3tsli/fj3i4+Px/vvvm7ZdvXoVly9ftun1jN2khw8frrV10LiPj4+PpLBQnw4dOuD48eM1th87dsz0vLONGTMGb731Fnr06FFrkHNX2dnZOHnyZI1p4uQ8HHNDLhUREYG2bdsiPT3drLvg66+/xtGjR81mshibsW/98DDOSrH1Q8UShUKBmJgYbNq0CefPnzdtP3HiBL7++mu7naeu8//tb3/Dhg0bcPjw4RrP3zyF2NgicHMLwA8//IDs7Gy713RrK8Pnn39uNmVfioiICPj7+yM9PR06nc60feXKlZK+l3Fxcfjf//6H5cuXo7i4uEaXlEKhgEwmM+tGO336NDZt2mRT3Te/7q3X46OPPjI7jxQPPvggvL29kZKSUuNGdcbzhIeHo3PnzliwYIGp2+NmlqaWW2PI
kCHIyckx+5kpLy/HkiVLEBoaatUYJHubOHEikpOTzcJjY3DmzBlMmDABarUaL7/8sqvLabbYckMupVKp8M477yAhIQGDBg3CmDFjTFPBQ0NDMX36dNO+4eHhAIAXX3wRsbGxUCgUGD16NDw8PNCzZ09kZGSga9euaN26NXr37l3n2AVrzJ49G//5z38wYMAAPP/889Dr9fjHP/6B3r174+DBgw16bWvMnz8fO3fuRFRUFCZNmoSePXvi4sWL2L9/P3bs2GGa+vrXv/4VGzduxKOPPoqhQ4ciNzcX6enp6Nmzp8UPQFv99a9/xdy5c5GQkID+/fvj0KFDWLNmjdXjY26lUqnw1ltv4dlnn8X999+PuLg45ObmYsWKFZJe8/HHH8eMGTMwY8YMtG7dukaLxtChQ7Fw4UI89NBDGDt2LAoLC5GWloYuXbrg559/tql24Nr1WL16NXx9fdGzZ09kZ2djx44ddd72oC4+Pj74+9//jokTJ+LOO+/E2LFj0apVK/z000+oqKjAqlWrIJfL8fHHH+Phhx9Gr169kJCQgODgYJw7dw47d+6Ej48P/v3vf0s+98yZM/Hpp5/i4YcfxosvvojWrVtj1apVyM3NxYYNG+ocN+coHTp0kLSW1/r16y12yT3wwAMICAiw6jXOnDljutWEcYmHt956y1TPuHHjzPbfv38/PvnkExgMBly+fBl79+7Fhg0bIJPJsHr16npvSUEO5JpJWtTUWJqiacmtU8GNMjIyRL9+/YRGoxGtW7cWTzzxhGn6uFF1dbV44YUXhL+/v5DJZGbTNPfs2SPCw8OFWq02m7Jb21TwW6fCCnFtKmp8fLzZtszMTNGvXz+hVqtF586dxccffyz+7//+T2i12nquyLWp4L169bJ4nqFDh9bYbqmugoICMWXKFBESEiJUKpUIDAwUgwcPFkuWLDHtYzAYxLx580SHDh2ERqMR/fr1E5s3bxbx8fGiQ4cOpv3qmr568zWrzdWrV8X//d//iaCgIOHh4SEGDBggsrOzxaBBg8ymbRu/x59//rnZ8cbz3zzdXQghFi1aJDp27Cg0Go2IiIgQ3333XY3XrM+AAQMEADFx4kSLzy9btkzcdtttQqPRiO7du4sVK1ZI+tkwPnfzNbp06ZJISEgQfn5+wsvLS8TGxopjx47V+Dmq7Xejtt+Fr776SvTv3194eHgIHx8fERkZKT799FOzfQ4cOCAee+wx0aZNG6HRaESHDh3E448/LjIzM+u5UrW/x5MnT4qRI0eKli1bCq1WKyIjI8XmzZst1nzr99aW892qtt+Lm0mdCm7p+tbF+P4sPW7+eTT+LBsfSqVStG7dWkRFRYmkpCRx5swZq89JjiETws4jKImauBEjRuCXX34xzVwhIiL3wjE3RHW4cuWK2de//fYbtm7dWmOBTiIich9suSGqQ1BQECZMmIBOnTrhzJkzWLx4MSorK3HgwIEa92chIvdUVFRU50BvtVqN1q1bO7EicjSGG6I6JCQkYOfOncjPz4dGo0F0dDTmzZuHO+64w9WlEZGVQkNDTTeNtGTQoEHIyspyXkHkcAw3RETUpH3//fc1uphv1qpVK9NsTGoaGG6IiIioSeGAYiIiImpSmt1N/AwGA86fPw9vb2+H3LafiIiI7E8IgbKyMrRr167eG0s2u3Bz/vz5GisyExERUeOQl5eH9u3b17lPsws3xuXn8/LyuForERFRI1FaWoqQkBDT53hdml24MXZF+fj4MNwQERE1MtYMKeGAYiIiImpSGG6IiIioSWG4ISIioiaF4YaIiIiaFIYbIiIialIYboiIiKhJYbghIiKiJsXl4SYtLQ2hoaHQarWIiopCTk5OnfunpqaiW7du8PDwQEhICKZPn46rV686qVoiIiJydy4NNxkZGUhMTERycjL279+Pvn37IjY2FoWFhRb3X7t2LWbOnInk5GQcPXoUy5YtQ0ZGBl599VUnV05ERETuyqXhZuHChZg0aRIS
EhLQs2dPpKenw9PTE8uXL7e4/549ezBgwACMHTsWoaGhePDBBzFmzJh6W3uIiIio+XBZuNHpdNi3bx9iYmJuFCOXIyYmBtnZ2RaP6d+/P/bt22cKM6dOncLWrVsxZMiQWs9TWVmJ0tJSswcRERE1XS5bW6q4uBh6vR4BAQFm2wMCAnDs2DGLx4wdOxbFxcW4++67IYRAdXU1nnvuuTq7pVJSUjBnzhy71k5ERETuy+UDiqXIysrCvHnzsGjRIuzfvx8bN27Eli1b8Oabb9Z6TFJSEkpKSkyPvLw8J1ZMREREzuaylhs/Pz8oFAoUFBSYbS8oKEBgYKDFY9544w2MGzcOEydOBAD06dMH5eXleOaZZ/Daa69BLq+Z1TQaDTQajf3fABEREbkll7XcqNVqhIeHIzMz07TNYDAgMzMT0dHRFo+pqKioEWAUCgUAQAjhuGKtVFmth97g+jqIiIiaM5e13ABAYmIi4uPjERERgcjISKSmpqK8vBwJCQkAgPHjxyM4OBgpKSkAgGHDhmHhwoXo168foqKicOLECbzxxhsYNmyYKeS4khBA2dUqtPRUu7oUIiKiZsul4SYuLg5FRUWYNWsW8vPzERYWhm3btpkGGZ89e9aspeb111+HTCbD66+/jnPnzsHf3x/Dhg3D22+/7aq3UEPJFYYbIiIiV5IJd+jPcaLS0lL4+vqipKQEPj4+dn3tq1V6nCj8Ez2CfKCQy+z62kRERM2ZlM/vRjVbqjEwdk0RERGRazDcOMDlCoYbIiIiV2G4cYA/K6s5a4qIiMhFGG4cQAig9Apbb4iIiFyB4cZBShhuiIiIXILhxkHYNUVEROQaDDcOwq4pIiIi12C4cSB2TRERETkfw40D/VlZjWq9wdVlEBERNSsMNw507YZ+1a4ug4iIqFlhuHEwhhsiIiLnYrhxsD8rGW6IiIicieHGwfQGgSs6vavLICIiajYYbpygrJKzpoiIiJyF4cYJ/uS4GyIiIqdhuHGCCp0eBt6tmIiIyCkYbpxACOBPHVtviIiInIHhxknYNUVEROQcDDdOwinhREREzsFw4ySVVQboqrkUAxERkaMx3DgRW2+IiIgcj+HGicoZboiIiByO4caJuM4UERGR4zHcOBGXYiAiInI8hhsn41IMREREjsVw42S83w0REZFjMdw4WYVODyG4FAMREZGjMNzY0Z+V1Th/+Uqd+wgBXK3i/W6IiIgcheHGTrYfKUB0SiZSM3+rd98KrjNFRETkMAw3dtKrnQ+q9ALH/ijF5QpdnfteqeKMKSIiIkdhuLGTdi090LOdDwSAnNMX69yX08GJiIgcxy3CTVpaGkJDQ6HVahEVFYWcnJxa97333nshk8lqPIYOHerEii0b3L0tAOCHU3WHm6tVBhgMHFRMRETkCC4PNxkZGUhMTERycjL279+Pvn37IjY2FoWFhRb337hxI/744w/T4/Dhw1AoFBg1apSTK6/JGG4O5l3G1Xq6ntg1RURE5BguDzcLFy7EpEmTkJCQgJ49eyI9PR2enp5Yvny5xf1bt26NwMBA02P79u3w9PR0i3DTLdAbbb010OkNOJB3uc59GW6IiIgcw6XhRqfTYd++fYiJiTFtk8vliImJQXZ2tlWvsWzZMowePRotWrRwVJlWk8lkiOrYGgDww6kLde7LcTdERESO4dJwU1xcDL1ej4CAALPtAQEByM/Pr/f4nJwcHD58GBMnTqx1n8rKSpSWlpo9HCmqUxsAwN7TF6GvY1wNW26IiIgcw+XdUg2xbNky9OnTB5GRkbXuk5KSAl9fX9MjJCTEoTX1CvJBC40CpVercSy/9iBVWWWoM/wQERGRbVwabvz8/KBQKFBQUGC2vaCgAIGBgXUeW15ejnXr1uHpp5+uc7+kpCSUlJSYHnl5eQ2uuy5KhRx3drjWNfW/emZNsfWGiIjI/lwabtRqNcLDw5GZmWnaZjAYkJmZiejo6DqP/fzzz1FZWYknn3yyzv00Gg18
fHzMHo5m7Jr6IfdCnetI8U7FRERE9ufybqnExEQsXboUq1atwtGjR/H888+jvLwcCQkJAIDx48cjKSmpxnHLli3DiBEj0KZNG2eXXK87/tISSrkMf5RcRd6l2teauqrjGlNERET2pnR1AXFxcSgqKsKsWbOQn5+PsLAwbNu2zTTI+OzZs5DLzTPY8ePHsXv3bvznP/9xRcn18lQrcXv7lth/9hJ+OHUBf2ntaXG/iiq23BAREdmbTNTVb9IElZaWwtfXFyUlJXbvorpapcdvBX8CAL4+/AcWZZ1EtwBvLBjVt9ZjegR5Q6lweQMaERGRW5Py+c1PVQeJDL02qPh4QRkulte+kGYFBxUTERHZFcONg7Tx0qBrgBeAa/e8qc1V3syPiIjIrhhuHOiOv7QCAPz8e0mt+1Qw3BAREdkVw40D3R7sCwA4dO5yrVPCea8bIiIi+2K4caBugT5QKWS4VFGF3y9bnhJerRfQVXNKOBERkb0w3DiQWilH98BrI7oP1dE1xdYbIiIi+2G4cbA+17umfj5Xe7i5ynBDRERkNww3DnZ7+2vh5vC5klrH3bBbioiIyH4Ybhysa4A31Eo5Sq5U4ezFCov7VDLcEBER2Q3DjYOpFHL0DLo+7qaWrqkqPcMNERGRvTDcOIFp3E0tg4qr9QIGQ7NaBYOIiMhhGG6cwHi/m8PnSmCobdwNW2+IiIjsguHGCbq09YJWJUdZZTXOXCi3uA/H3RAREdkHw40TKG8ad1Nb1xTH3RAREdkHw42T9AluCaD2QcWcDk5ERGQfDDdOYrrfzfkS6C0MHma4ISIisg+GGyfp7O8FD5UC5ZV65BbXHHfDAcVERET2wXDjJAq5DL3aGe93c7nG82y5ISIisg+GGycydk1ZGlQsBAcVExER2QPDjRMZBxUf+aPU4v1u2HpDRETUcAw3dqRSyCGT1f58R78WUCvkqNDp8cflqzWeZ7ghIiJqOIYbO1LIZfBQK+p8vqNfCwDAyaI/azzPQcVEREQNx3BjZ95aZZ3Pd/K/Fm5OFVsIN2y5ISIiajCGGzvz1qjqfL6zvxcA4GQRp4MTERE5AsONnXmoFVAqah94Ywo3hX9C3DKomC03REREDcdw4wBemtq7pjq08YRCLkNZZTWKyirNnqvWCxgs3L2YiIiIrMdw4wA+2tq7plQKOTq09gQAnOSdiomIiOyO4cYBvLTKOqeEGwcVc8YUERGR/THcOEB9U8JvHndzK467ISIiahiGGwepa0q4MdycsjRjiuGGiIioQRhuHKSuKeEd/VpABuBihQ6XynVmzzHcEBERNQzDjYPUNSVcq1IguJUHAODkLTfz4+KZREREDePycJOWlobQ0FBotVpERUUhJyenzv0vX76MKVOmICgoCBqNBl27dsXWrVudVK00dU0Jr+1mfpVsuSEiImoQl4abjIwMJCYmIjk5Gfv370ffvn0RGxuLwsJCi/vrdDo88MADOH36NNavX4/jx49j6dKlCA4OdnLl1qlrSnhn44ypWwYVC8HWGyIiooaoeyEkB1u4cCEmTZqEhIQEAEB6ejq2bNmC5cuXY+bMmTX2X758OS5evIg9e/ZApboWHEJDQ51ZsiTGKeHCwn35TIOKa1ljSqVweaMaERFRo+SyT1CdTod9+/YhJibmRjFyOWJiYpCdnW3xmK+++grR0dGYMmUKAgIC0Lt3b8ybNw96vb7W81RWVqK0tNTs4Sx1TQnv5Hct3BSUVuLPq9Vmz7HlhoiIyHYuCzfFxcXQ6/UICAgw2x4QEID8/HyLx5w6dQrr16+HXq/H1q1b8cYbb+D999/HW2+9Vet5UlJS4Ovra3qEhITY9X3Up7ZxN15aJQJ8NABqDirmjCkiIiLbNaq+D4PBgLZt22LJkiUIDw9HXFwcXnvtNaSnp9d6TFJSEkpKSkyPvLw8J1YMeNpwMz8OKiYiIrKdy8bc+Pn5QaFQoKCgwGx7QUEBAgMDLR4TFBQElUoFheJGYOjR
owfy8/Oh0+mgVqtrHKPRaKDRaOxbvAQt1HWPu9lz8gJO3bLGFJdgICIisp3LWm7UajXCw8ORmZlp2mYwGJCZmYno6GiLxwwYMAAnTpyAwXDjw//XX39FUFCQxWDjDuRyGbQqy5e5tjWmOOaGiIjIdi7tlkpMTMTSpUuxatUqHD16FM8//zzKy8tNs6fGjx+PpKQk0/7PP/88Ll68iJdeegm//vortmzZgnnz5mHKlCmuegtWaVHLuBtjt9S5S1dwRXdjUHRVtYCw1NRDRERE9XLpVPC4uDgUFRVh1qxZyM/PR1hYGLZt22YaZHz27FnI5TfyV0hICL755htMnz4dt99+O4KDg/HSSy/hlVdecdVbsIqnWglAV2N7K081WrdQ42K5DrkXytEzyMf0nEEAtdzgmIiIiOogE82siaC0tBS+vr4oKSmBj49P/QfYQbXegKN/lFl8bu7mX7D39CU8M7AThvVtZ9reI8gbSt7rhoiICIC0z29+ejqBUiGHppZxN6Ftro27ybtUYba9WSVOIiIiO2K4cZLapoQH+GgBAIVllWbbm1d7GhERkf0w3DhJC7Xl4U2B18NNQelVs+2CbTdEREQ2YbhxEk+N5ZabttfvUlxYWmk2Q4otN0RERLZhuHESjVIBpYXpT/5eGshl127cd6miygWVERERNS0MN05kqWtKqZCjjde11pubu6bYckNERGQbhhsnqq1rKsC7ZrghIiIi2zDcOFFtg4oDLAwq5oBiIiIi2zDcOJFWJYfcwhW/EW5uTAdntxQREZFtGG6cSCaTWWy9sdxyQ0RERLZguHEyS+NuAq5PB883G1DMeENERGQLhhsns9RyY7yRX/GfldAbroUaRhsiIiLbMNw4mYdKAdktt7tp1UINlUIGg7gWcACOuSEiIrIVw42TyeUyqJXml10uk6Gt9y3jbhhuiIiIbMJw4wIeqprjbtrecq8bTgUnIiKyDcONC2hUNS/7rdPB2S1FRERkG4YbF7DUcnPrdHBmGyIiItsw3LiA1mK4uaVbik03RERENmG4cQGVQg6F3HzKVI1uKadXRURE1DQw3LiIh9q89cYYbi5W6FBZreeYGyIiIhsx3LiI9pZBxT5apWksTmFZJWdLERER2YjhxkVuHVQsk8lqjLshIiIi6RhuXMTyoOKbxt2w4YaIiMgmDDcuolHKayzDYAw3haVXmW2IiIhsxHDjIjKZrMa4m5u7pTigmIiIyDYMNy6kUZp3Td1YX4oDiomIiGzFcONCtU0HZ8sNERGR7RhuXOjWQcXGbqmyymr8WVntipKIiIgaPYYbF9IqzS+/p1oJb60SAHDu8hVXlERERNToMdy4kFIhh0ppeRmG8ww3RERENmG4cTGt0vK4mz8u80Z+REREtnCLcJOWlobQ0FBotVpERUUhJyen1n1XrlwJmUxm9tBqtU6s1r5uHVQceH3czR8lbLkhIiKyhcvDTUZGBhITE5GcnIz9+/ejb9++iI2NRWFhYa3H+Pj44I8//jA9zpw548SK7au2lpvzJWy5ISIisoXLw83ChQsxadIkJCQkoGfPnkhPT4enpyeWL19e6zEymQyBgYGmR0BAgBMrti/NrTfy8zZ2S7HlhoiIyBYuDTc6nQ779u1DTEyMaZtcLkdMTAyys7NrPe7PP/9Ehw4dEBISguHDh+OXX36pdd/KykqUlpaaPdzJrcswGFtu8kuvQvBmN0RERJK5NNwUFxdDr9fXaHkJCAhAfn6+xWO6deuG5cuX48svv8Qnn3wCg8GA/v374/fff7e4f0pKCnx9fU2PkJAQu7+Phri2DMONrqm218fcXK0yoORKlavKIiIiarRc3i0lVXR0NMaPH4+wsDAMGjQIGzduhL+/P/75z39a3D8pKQklJSWmR15enpMrrt/Ng4pVCjnk11tyKqsNLqqIiIio8VK68uR+fn5QKBQoKCgw215QUIDAwECrXkOlUqFfv344ceKExec1Gg00Gk2Da3UktcI8
YyoVcuiqDdAx3BAREUnm0pYbtVqN8PBwZGZmmrYZDAZkZmYiOjraqtfQ6/U4dOgQgoKCHFWmw6lvuVOx8nrTTbWBY26IiIikcmnLDQAkJiYiPj4eERERiIyMRGpqKsrLy5GQkAAAGD9+PIKDg5GSkgIAmDt3Lu666y506dIFly9fxnvvvYczZ85g4sSJrnwbDVKj5eZ6uKnSs+WGiIhIKpeHm7i4OBQVFWHWrFnIz89HWFgYtm3bZhpkfPbsWcjlNz78L126hEmTJiE/Px+tWrVCeHg49uzZg549e7rqLTSYSmG+BIPy+vtluCEiIpJOJprZfOPS0lL4+vqipKQEPj4+ri7H5JfzJTBczzJPr9qLwrJKbJoyAGEhLV1aFxERkTuQ8vnd6GZLNVU3d00pjGNu2HJDREQkmeRuqcuXL+OLL77Arl27cObMGVRUVMDf3x/9+vVDbGws+vfv74g6mzyVQo6rVdfCjPJ60NEx3BAREUlmdcvN+fPnMXHiRAQFBeGtt97ClStXEBYWhsGDB6N9+/bYuXMnHnjgAfTs2RMZGRmOrLlJunnGlGm2lL5Z9RgSERHZhdUtN/369UN8fDz27dtX6+DdK1euYNOmTUhNTUVeXh5mzJhht0KbOpWiZrjhgGIiIiLprA43R44cQZs2bercx8PDA2PGjMGYMWNw4cKFBhfXnFhqualiyw0REZFkVndL1RdsGrp/c3fzgGLjmBu23BAREUnXoPvcCCGQlZWFEydOICgoCLGxsVCpVPaqrVm5+V43ptlSBoYbIiIiqSSFmyFDhuDTTz+Fr68vLl68iCFDhiAnJwd+fn64cOECunbtiu+++w7+/v6OqrfJUirkkMsBg+GmbqlqdksRERFJJek+N9u2bUNlZSUA4PXXX0dZWRlOnjyJwsJCnDlzBi1atMCsWbMcUmhzYOyaUl5vxaliyw0REZFkNt/E79tvv0VKSgo6duwIAGjfvj3eeecdfPPNN3YrrrkxzpgyLr/AVcGJiIikkxxuZLJrrQqXLl1C586dzZ7r0qULzp8/b5/KmiHjjClOBSciIrKd5AHFEyZMgEajQVVVFXJzc9GrVy/Tc/n5+WjZsqU962tWVLd2S3EqOBERkWSSwk18fLzp38OHD0dFRYXZ8xs2bEBYWJhdCmuOjGNuFMZVwdktRUREJJmkcLNixYo6n09OToZCoWhQQc2ZsVtKdb1bimtLERERSdeg+9zcqkWLFvZ8uWbHeK8bBe9QTEREZDPJA4qPHDmCyZMno1+/fggKCkJQUBD69euHyZMn48iRI46osdlQKuSQyW6+Q7HexRURERE1PpJabr7++muMGDECd9xxB4YPH46AgAAAQEFBAbZv34477rgDX375JWJjYx1SbHOgUcq5thQREVEDSAo3M2fOxCuvvIK5c+fWeG727NmYPXs2Xn75ZYabBlAp5DfNluKYGyIiIqkkdUv9+uuveOKJJ2p9fsyYMfjtt98aXFRzplbKOeaGiIioASSFm9DQUGzZsqXW57ds2YIOHTo0uKjmTKWQQyXnquBERES2ktQtNXfuXIwdOxZZWVmIiYkxG3OTmZmJbdu2Ye3atQ4ptLlQK25uuWG4ISIikkpSuBk1ahSCg4Px4Ycf4v3330d+fj4AIDAwENHR0cjKykJ0dLRDCm0u1Eo571BMRETUAJLvc9O/f3/079/fEbUQrt3rht1SREREtrN5VXByDOVNs6Wq2XJDREQkmV3DzauvvoqnnnrKni/ZLHmori1hwZYbIiIi6ey6/MK5c+eQl5dnz5dsljSqa5mz2sCWGyIiIqnsGm5WrVplz5drtrRsuSEiIrIZx9y4IYYbIiIi20luuSkuLsby5cuRnZ1tNhW8f//+mDBhAvz9/e1eZHNzY8wNu6WIiIikktRys3fvXnTt2hUffvghfH19cc899+Cee+6Br68vPvzwQ3Tv3h0//vijo2ptNozhptrAlhsiIiKpJLXcvPDCCxg1ahTS
09Mhk8nMnhNC4LnnnsMLL7yA7OxsuxbZ3Hior9/nppotN0RERFJJarn56aefMH369BrBBgBkMhmmT5+OgwcPSi4iLS0NoaGh0Gq1iIqKQk5OjlXHrVu3DjKZDCNGjJB8TnemYcsNERGRzSSFm8DAwDqDR05Ojmm9KWtlZGQgMTERycnJ2L9/P/r27YvY2FgUFhbWedzp06cxY8YMDBw4UNL5GgO1wniHYrbcEBERSSWpW2rGjBl45plnsG/fPgwePLjGwplLly7FggULJBWwcOFCTJo0CQkJCQCA9PR0bNmyBcuXL8fMmTMtHqPX6/HEE09gzpw52LVrFy5fvizpnO5OqTDe54YtN0RERFJJCjdTpkyBn58f/v73v2PRokXQ6/UAAIVCgfDwcKxcuRKPP/641a+n0+mwb98+JCUlmbbJ5XLExMTUOW5n7ty5aNu2LZ5++mns2rVLyltoFFRcfoGIiMhmkqeCx8XFIS4uDlVVVSguLgYA+Pn5QaVSST55cXEx9Hp9ja6sgIAAHDt2zOIxu3fvxrJly6we21NZWYnKykrT16WlpZLrdDaVseWG4YaIiEgym2/ip1KpEBQUhKCgIJuCjS3Kysowbtw4LF26FH5+flYdk5KSAl9fX9MjJCTEwVU2nDHc6IWAgUswEBERSWJzuJk/f75prMvN/5bCz88PCoUCBQUFZtsLCgoQGBhYY/+TJ0/i9OnTGDZsGJRKJZRKJf71r3/hq6++glKpxMmTJ2sck5SUhJKSEtOjMax9ZVwVHACqOO6GiIhIEpvDzbx583Dx4sUa/5ZCrVYjPDwcmZmZpm0GgwGZmZmIjo6usX/37t1x6NAhHDx40PR45JFHcN999+HgwYMWW2U0Gg18fHzMHu7OOFsK4IwpIiIiqWxeOFMIYfHfUiUmJiI+Ph4RERGIjIxEamoqysvLTbOnxo8fj+DgYKSkpECr1aJ3795mx7ds2RIAamxvzJTyGy031VxfioiISBK7rgpui7i4OBQVFWHWrFnIz89HWFgYtm3bZhpkfPbsWcjlzWt9T4VcBhkAAUDHcENERCSJy8MNAEydOhVTp061+FxWVladx65cudL+BbmYTCaDUiFDlV5wxhQREZFEzatJpBFRyo13KWbLDRERkRQMN27KOGOK4YaIiEgau4QbSwtpUsOouL4UERGRTewSbhoyW4osM86YYssNERGRNDYPKD5y5AiCg4NN/27Xrp3diqKbu6UYHImIiKSwOdzcfMO8xrCkQWPDAcVERES2salbSqFQoLCwsMb2CxcuQKFQNLgounGXYk4FJyIiksamcFPbGJvKykqo1eoGFUTXcLYUERGRbSR1S3344YcArs2O+vjjj+Hl5WV6Tq/X47vvvkP37t3tW2EzxXBDRERkG0nh5u9//zuAay036enpZl1QarUaoaGhSE9Pt2+FzZSaU8GJiIhsIinc5ObmAgDuu+8+bNy4Ea1atXJIUXRjQHG1gS03REREUtg0W2rnzp32roNuoVJe65bSVTPcEBERSWH35Rfmzp2LXbt22ftlmx2VqeWG3VJERERS2D3crFixArGxsRg2bJi9X7pZubH8AltuiIiIpLD5Jn61yc3NxZUrV9h11UC8QzEREZFtHLIquIeHB4YMGeKIl2422HJDRERkG5vCzezZs2GwMIunpKQEY8aMaXBRBKiut9xUM9wQERFJYlO4WbZsGe6++26cOnXKtC0rKwt9+vTByZMn7VZcc6ZWXvvW6NgtRUREJIlN4ebnn39G+/btERYWhqVLl+Lll1/Ggw8+iHHjxmHPnj32rrFZMt3nhi03REREktg0oLhVq1b47LPP8Oqrr+LZZ5+FUqnE119/jcGDB9u7vmbL2HLDMTdERETS2Dyg+KOPPsIHH3yAMWPGoFOnTnjxxRfx008/2bO2Zk3F2VJEREQ2sSncPPTQQ5gzZw5WrVqFNWvW4MCBA7jnnntw11134d1337V3jc0SZ0sRERHZxqZw
o9fr8fPPP2PkyJEArk39Xrx4MdavX29aXJMahuGGiIjINjaNudm+fbvF7UOHDsWhQ4caVBBdw9lSREREtrG65UYI6z5k/fz8bC6GblDJeZ8bIiIiW1gdbnr16oV169ZBp9PVud9vv/2G559/HvPnz29wcc0ZZ0sRERHZxupuqY8++givvPIKJk+ejAceeAARERFo164dtFotLl26hCNHjmD37t345ZdfMHXqVDz//POOrLvJM4650VUz3BAREUlhdbgZPHgwfvzxR+zevRsZGRlYs2YNzpw5gytXrsDPzw/9+vXD+PHj8cQTT6BVq1aOrLlZMIabagPH3BAREUkheUDx3XffjbvvvtsRtdBNjKuCs+WGiIhIGptmS82dO7fO52fNmmVTMXSDmi03RERENrEp3HzxxRdmX1dVVSE3NxdKpRKdO3dmuLEDJe9zQ0REZBObws2BAwdqbCstLcWECRPw6KOPNrgo4vILREREtrJ5balb+fj4YM6cOXjjjTckH5uWlobQ0FBotVpERUUhJyen1n03btyIiIgItGzZEi1atEBYWBhWr17dkNLdEu9QTEREZBu7hRsAKCkpQUlJiaRjMjIykJiYiOTkZOzfvx99+/ZFbGwsCgsLLe7funVrvPbaa8jOzsbPP/+MhIQEJCQk4JtvvrHHW3AbptlSDDdERESS2NQt9eGHH5p9LYTAH3/8gdWrV+Phhx+W9FoLFy7EpEmTkJCQAABIT0/Hli1bsHz5csycObPG/vfee6/Z1y+99BJWrVqF3bt3IzY2VtobcWNKdksRERHZxKZwc+vimHK5HP7+/oiPj0dSUpLVr6PT6bBv3z6zY+RyOWJiYpCdnV3v8UIIfPvttzh+/Djeeecdi/tUVlaisrLS9HVpaanV9bmSmt1SRERENrEp3OTm5trl5MXFxdDr9QgICDDbHhAQgGPHjtV6XElJCYKDg1FZWQmFQoFFixbhgQcesLhvSkoK5syZY5d6nYljboiIiGxj1zE3zuLt7Y2DBw9i7969ePvtt5GYmIisrCyL+yYlJZnGApWUlCAvL8+5xdrI2C3F+9wQERFJY1PLjb34+flBoVCgoKDAbHtBQQECAwNrPU4ul6NLly4AgLCwMBw9ehQpKSk1xuMAgEajgUajsWvdzsBuKSIiItu4tOVGrVYjPDwcmZmZpm0GgwGZmZmIjo62+nUMBoPZuJqmwNRywwHFREREkri05QYAEhMTER8fj4iICERGRiI1NRXl5eWm2VPjx49HcHAwUlJSAFwbQxMREYHOnTujsrISW7duxerVq7F48WJXvg27u3nhTCEEZDKZiysiIiJqHFwebuLi4lBUVIRZs2YhPz8fYWFh2LZtm2mQ8dmzZyGX32hgKi8vx+TJk/H777/Dw8MD3bt3xyeffIK4uDhXvQWHUN30nqsNwnTHYiIiIqqbTAjRrPo9SktL4evri5KSEvj4+Li6nFpV6KrRc9a1GxMemRsLT7XLcygREZHLSPn8bpSzpZoD5U0tN7yRHxERkfUYbtzUzd1QnDFFRERkPYYbNyWTyaCQc8YUERGRVAw3bkxlWl+KLTdERETWYrhxY8ZxNww3RERE1mO4cWMqrgxOREQkGcONG1NyCQYiIiLJGG7cGMfcEBERScdw48aMY264MjgREZH1GG7cmKnlppotN0RERNZiuHFjxsUzdeyWIiIishrDjRszrQzO2VJERERWY7hxY0oOKCYiIpKM4caNqY1TwTmgmIiIyGoMN25MyQHFREREkjHcuDHTmBsDww0REZG1GG7c2I3ZUuyWIiIishbDjRtTm2ZLseWGiIjIWgw3boyzpYiIiKRjuHFjptlS7JYiIiKyGsONG1MruSo4ERGRVAw3box3KCYiIpKO4caNqTjmhoiISDKGGzem5JgbIiIiyRhu3JhKwTE3REREUjHcuDGV/Fq3FO9QTEREZD2GGzemuj5bSlfNbikiIiJrMdy4MSVbboiIiCRjuHFj
vM8NERGRdAw3bkwp52wpIiIiqRhu3Bjvc0NERCSdW4SbtLQ0hIaGQqvVIioqCjk5ObXuu3TpUgwcOBCtWrVCq1atEBMTU+f+jRnvUExERCSdy8NNRkYGEhMTkZycjP3796Nv376IjY1FYWGhxf2zsrIwZswY7Ny5E9nZ2QgJCcGDDz6Ic+fOOblyxzOGGx1bboiIiKzm8nCzcOFCTJo0CQkJCejZsyfS09Ph6emJ5cuXW9x/zZo1mDx5MsLCwtC9e3d8/PHHMBgMyMzMdHLljqe83i1VzXBDRERkNZeGG51Oh3379iEmJsa0TS6XIyYmBtnZ2Va9RkVFBaqqqtC6dWuLz1dWVqK0tNTs0Vio2XJDREQkmUvDTXFxMfR6PQICAsy2BwQEID8/36rXeOWVV9CuXTuzgHSzlJQU+Pr6mh4hISENrttZlKYBxRxzQ0REZC2Xd0s1xPz587Fu3Tp88cUX0Gq1FvdJSkpCSUmJ6ZGXl+fkKm13Y0AxW26IiIispXTlyf38/KBQKFBQUGC2vaCgAIGBgXUeu2DBAsyfPx87duzA7bffXut+Go0GGo3GLvU6m3EqOLuliIiIrOfSlhu1Wo3w8HCzwcDGwcHR0dG1Hvfuu+/izTffxLZt2xAREeGMUl2CU8GJiIikc2nLDQAkJiYiPj4eERERiIyMRGpqKsrLy5GQkAAAGD9+PIKDg5GSkgIAeOeddzBr1iysXbsWoaGhprE5Xl5e8PLyctn7cIQbdyhmyw0REZG1XB5u4uLiUFRUhFmzZiE/Px9hYWHYtm2baZDx2bNnIZffaGBavHgxdDodRo4cafY6ycnJmD17tjNLdzi10jgVnC03RERE1pIJIZrVJ2dpaSl8fX1RUlICHx8fV5dTp9PF5bh3QRY81QocmfuQq8shIiJyGSmf3416tlRTp7q+Kni1oVnlTyIiogZhuHFjKjkXziQiIpKK4caNGWdLCQHo2XpDRERkFYYbN2a8QzHA1hsiIiJrMdy4MWPLDcBwQ0REZC2GGzdmHm7YLUVERGQNhhs3ppDLcH1MMdeXIiIishLDjZsztt5wfSkiIiLrMNy4OeOgYt6lmIiIyDoMN27O2HLDAcVERETWYbhxcyrT4plsuSEiIrIGw42bUyl4l2IiIiIpGG7cnFJhXF+K4YaIiMgaDDduzthyo6tmtxQREZE1GG7cnIotN0RERJIw3Lg5zpYiIiKShuHGzd0YUMxuKSIiImsw3Lg5ttwQERFJw3Dj5hhuiIiIpGG4cXPsliIiIpKG4cbNKdlyQ0REJAnDjZtTG6eCs+WGiIjIKgw3bk7J5ReIiIgkYbhxczcGFLPlhoiIyBoMN26OC2cSERFJw3Dj5kzLLzDcEBERWYXhxs0p5de+RTp2SxEREVmF4cbNqZTXuqXYckNERGQdhhs3p5LzPjdERERSMNy4OdNsKQO7pYiIiKzBcOPmTPe5qWbLDRERkTVcHm7S0tIQGhoKrVaLqKgo5OTk1LrvL7/8gr/97W8IDQ2FTCZDamqq8wp1EdMditlyQ0REZBWXhpuMjAwkJiYiOTkZ+/fvR9++fREbG4vCwkKL+1dUVKBTp06YP38+AgMDnVytaxhbbnQcc0NERGQVl4abhQsXYtKkSUhISEDPnj2Rnp4OT09PLF++3OL+d955J9577z2MHj0aGo3GydW6Bu9zQ0REJI3Lwo1Op8O+ffsQExNzoxi5HDExMcjOznZVWW5HzeUXiIiIJFG66sTFxcXQ6/UICAgw2x4QEIBjx47Z7TyVlZWorKw0fV1aWmq313YGLpxJREQkjcsHFDtaSkoKfH19TY+QkBBXlyTJjYUzGW6IiIis4bJw4+fnB4VCgYKCArPtBQUFdh0snJSUhJKSEtMjLy/Pbq/tDMaFM6vZLUVERGQVl4UbtVqN8PBwZGZmmrYZDAZkZmYiOjrabufRaDTw8fExezQmbLkhIiKSxmVjbgAgMTER8fHxiIiIQGRkJFJTU1Fe
Xo6EhAQAwPjx4xEcHIyUlBQA1wYhHzlyxPTvc+fO4eDBg/Dy8kKXLl1c9j4cSckBxURERJK4NNzExcWhqKgIs2bNQn5+PsLCwrBt2zbTIOOzZ89CLr/RuHT+/Hn069fP9PWCBQuwYMECDBo0CFlZWc4u3ylUHFBMREQkiUvDDQBMnToVU6dOtfjcrYElNDQUQjSvFgwV71BMREQkSZOfLdXYGcONjmtLERERWYXhxs0p5eyWIiIikoLhxs2pleyWIiIikoLhxs2ZWm7YLUVERGQVhhs3Z7rPjYHhhoiIyBoMN25OxfvcEBERScJw4+aM97nRGwQMHHdDRERUL4YbN2e8QzHArikiIiJrMNy4OfVN4YaLZxIREdWP4cbNKa93SwG81w0REZE1GG7cnHEqOMBBxURERNZguHFzMpmMi2cSERFJwHDTCJgWz2TLDRERUb0YbhoBY9eUji03RERE9WK4aQRurC/FcENERFQfhptGQCm/fpfianZLERER1YfhphFQKa8PKGbLDRERUb0YbhoBlanlhuGGiIioPgw3jYBpthTXliIiIqoXw00jYOyW4mwpIiKi+jHcNALGAcW8zw0REVH9GG4aAePimbxDMRERUf0YbhoBJZdfICIishrDTSOgMrXcsFuKiIioPgw3jYBx4cxqttwQERHVi+GmEVBxzA0REZHVGG4aASW7pYiIiKzGcNMIqDigmIiIyGoMN42AafkFhhsiIqJ6Mdw0AqaFM9ktRUREVC+Gm0ZAyZYbIiIiqzHcNAJqJRfOJCIispZbhJu0tDSEhoZCq9UiKioKOTk5de7/+eefo3v37tBqtejTpw+2bt3qpEpdQym/vnBmNVtuiIiI6uPycJORkYHExEQkJydj//796Nu3L2JjY1FYWGhx/z179mDMmDF4+umnceDAAYwYMQIjRozA4cOHnVy58xjvc1NtYLghIiKqj0wI4dK+jqioKNx55534xz/+AQAwGAwICQnBCy+8gJkzZ9bYPy4uDuXl5di8ebNp21133YWwsDCkp6fXe77S0lL4+vqipKQEPj4+9nsjDvSPb3/Dgv/8ir/eHoSZD3d3dTlERER1UivlaOuttetrSvn8Vtr1zBLpdDrs27cPSUlJpm1yuRwxMTHIzs62eEx2djYSExPNtsXGxmLTpk0W96+srERlZaXp69LS0oYX7mTGlpvNP/+BzT//4eJqiIiI6nbHX1pi4+QBLju/S8NNcXEx9Ho9AgICzLYHBATg2LFjFo/Jz8+3uH9+fr7F/VNSUjBnzhz7FOwiA7r4IcBHg8sVVa4uhYiIqF7GP8pdxaXhxhmSkpLMWnpKS0sREhLiwoqk6x3six9ejXF1GURERI2CS8ONn58fFAoFCgoKzLYXFBQgMDDQ4jGBgYGS9tdoNNBoNPYpmIiIiNyeS9uN1Go1wsPDkZmZadpmMBiQmZmJ6Ohoi8dER0eb7Q8A27dvr3V/IiIial5c3i2VmJiI+Ph4REREIDIyEqmpqSgvL0dCQgIAYPz48QgODkZKSgoA4KWXXsKgQYPw/vvvY+jQoVi3bh1+/PFHLFmyxJVvg4iIiNyEy8NNXFwcioqKMGvWLOTn5yMsLAzbtm0zDRo+e/Ys5PIbDUz9+/fH2rVr8frrr+PVV1/Fbbfdhk2bNqF3796uegtERETkRlx+nxtna4z3uSEiImrupHx+u/wOxURERET2xHBDRERETQrDDRERETUpDDdERETUpDDcEBERUZPCcENERERNCsMNERERNSkMN0RERNSkMNwQERFRk+Ly5ReczXhD5tLSUhdXQkRERNYyfm5bs7BCsws3ZWVlAICQkBAXV0JERERSlZWVwdfXt859mt3aUgaDAefPn4e3tzdkMpldX7u0tBQhISHIy8vjulUOxOvsHLzOzsHr7Dy81s7hqOsshEBZWRnatWtntqC2Jc2u5UYul6N9+/YOPYePjw9/cZyA19k5eJ2dg9fZeXitncMR17m+Fhsj
DigmIiKiJoXhhoiIiJoUhhs70mg0SE5OhkajcXUpTRqvs3PwOjsHr7Pz8Fo7hztc52Y3oJiIiIiaNrbcEBERUZPCcENERERNCsMNERERNSkMN0RERNSkMNxIlJaWhtDQUGi1WkRFRSEnJ6fO/T///HN0794dWq0Wffr0wdatW51UaeMm5TovXboUAwcORKtWrdCqVSvExMTU+32ha6T+PButW7cOMpkMI0aMcGyBTYTU63z58mVMmTIFQUFB0Gg06Nq1K//vsILU65yamopu3brBw8MDISEhmD59Oq5eveqkahun7777DsOGDUO7du0gk8mwadOmeo/JysrCHXfcAY1Ggy5dumDlypUOrxOCrLZu3TqhVqvF8uXLxS+//CImTZokWrZsKQoKCizu//333wuFQiHeffddceTIEfH6668LlUolDh065OTKGxep13ns2LEiLS1NHDhwQBw9elRMmDBB+Pr6it9//93JlTcuUq+zUW5urggODhYDBw4Uw4cPd06xjZjU61xZWSkiIiLEkCFDxO7du0Vubq7IysoSBw8edHLljYvU67xmzRqh0WjEmjVrRG5urvjmm29EUFCQmD59upMrb1y2bt0qXnvtNbFx40YBQHzxxRd17n/q1Cnh6ekpEhMTxZEjR8RHH30kFAqF2LZtm0PrZLiRIDIyUkyZMsX0tV6vF+3atRMpKSkW93/88cfF0KFDzbZFRUWJZ5991qF1NnZSr/Otqqurhbe3t1i1apWjSmwSbLnO1dXVon///uLjjz8W8fHxDDdWkHqdFy9eLDp16iR0Op2zSmwSpF7nKVOmiPvvv99sW2JiohgwYIBD62xKrAk3/+///T/Rq1cvs21xcXEiNjbWgZUJwW4pK+l0Ouzbtw8xMTGmbXK5HDExMcjOzrZ4THZ2ttn+ABAbG1vr/mTbdb5VRUUFqqqq0Lp1a0eV2ejZep3nzp2Ltm3b4umnn3ZGmY2eLdf5q6++QnR0NKZMmYKAgAD07t0b8+bNg16vd1bZjY4t17l///7Yt2+fqevq1KlT2Lp1K4YMGeKUmpsLV30ONruFM21VXFwMvV6PgIAAs+0BAQE4duyYxWPy8/Mt7p+fn++wOhs7W67zrV555RW0a9euxi8U3WDLdd69ezeWLVuGgwcPOqHCpsGW63zq1Cl8++23eOKJJ7B161acOHECkydPRlVVFZKTk51RdqNjy3UeO3YsiouLcffdd0MIgerqajz33HN49dVXnVFys1Hb52BpaSmuXLkCDw8Ph5yXLTfUpMyfPx/r1q3DF198Aa1W6+pymoyysjKMGzcOS5cuhZ+fn6vLadIMBgPatm2LJUuWIDw8HHFxcXjttdeQnp7u6tKalKysLMybNw+LFi3C/v37sXHjRmzZsgVvvvmmq0sjO2DLjZX8/PygUChQUFBgtr2goACBgYEWjwkMDJS0P9l2nY0WLFiA+fPnY8eOHbj99tsdWWajJ/U6nzx5EqdPn8awYcNM2wwGAwBAqVTi+PHj6Ny5s2OLboRs+XkOCgqCSqWCQqEwbevRowfy8/Oh0+mgVqsdWnNjZMt1fuONNzBu3DhMnDgRANCnTx+Ul5fjmWeewWuvvQa5nH/720Ntn4M+Pj4Oa7UB2HJjNbVajfDwcGRmZpq2GQwGZGZmIjo62uIx0dHRZvsDwPbt22vdn2y7zgDw7rvv4s0338S2bdsQERHhjFIbNanXuXv37jh06BAOHjxoejzyyCO47777cPDgQYSEhDiz/EbDlp/nAQMG4MSJE6bwCAC//vorgoKCGGxqYct1rqioqBFgjIFScMlFu3HZ56BDhys3MevWrRMajUasXLlSHDlyRDzzzDOiZcuWIj8/XwghxLhx48TMmTNN+3///fdCqVSKBQsWiKNHj4rk5GROBbeC1Os8f/58oVarxfr168Uff/xhepSVlbnqLTQKUq/zrThbyjpSr/PZs2eFt7e3mDp1qjh+/LjYvHmzaNu2rXjrrbdc
9RYaBanXOTk5WXh7e4tPP/1UnDp1SvznP/8RnTt3Fo8//rir3kKjUFZWJg4cOCAOHDggAIiFCxeKAwcOiDNnzgghhJg5c6YYN26caX/jVPCXX35ZHD16VKSlpXEquDv66KOPxF/+8hehVqtFZGSk+N///md6btCgQSI+Pt5s/88++0x07dpVqNVq0atXL7FlyxYnV9w4SbnOHTp0EABqPJKTk51feCMj9ef5Zgw31pN6nffs2SOioqKERqMRnTp1Em+//baorq52ctWNj5TrXFVVJWbPni06d+4stFqtCAkJEZMnTxaXLl1yfuGNyM6dOy3+f2u8tvHx8WLQoEE1jgkLCxNqtVp06tRJrFixwuF1yoRg+xsRERE1HRxzQ0RERE0Kww0RERE1KQw3RERE1KQw3BAREVGTwnBDRERETQrDDRERETUpDDdERETUpDDcEBERUZPCcENERERNCsMNERERNSkMN0TU6BUVFSEwMBDz5s0zbduzZw/UanWNFYmJqOnj2lJE1CRs3boVI0aMwJ49e9CtWzeEhYVh+PDhWLhwoatLIyInY7ghoiZjypQp2LFjByIiInDo0CHs3bsXGo3G1WURkZMx3BBRk3HlyhX07t0beXl52LdvH/r06ePqkojIBTjmhoiajJMnT+L8+fMwGAw4ffq0q8shIhdhyw0RNQk6nQ6RkZEICwtDt27dkJqaikOHDqFt27auLo2InIzhhoiahJdffhnr16/HTz/9BC8vLwwaNAi+vr7YvHmzq0sjIidjtxQRNXpZWVlITU3F6tWr4ePjA7lcjtWrV2PXrl1YvHixq8sjIidjyw0RERE1KWy5ISIioiaF4YaIiIiaFIYbIiIialIYboiIiKhJYbghIiKiJoXhhoiIiJoUhhsiIiJqUhhuiIiIqElhuCEiIqImheGGiIiImhSGGyIiImpSGG6IiIioSfn/GKswCQKkH+8AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "t_idx = 1\n",
    "with torch.no_grad():\n",
    "    plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "    plt.title(\"Plotting mean and variance for {dataset}\".format(dataset=dataset))\n",
    "    plt.xlabel(\"x\")\n",
    "    mu = mu_true[:, t_idx, :].squeeze(-1)\n",
    "    std = torch.sqrt(var_true[:, t_idx, :]).squeeze(-1)\n",
    "    plt.plot(grid, mu)\n",
    "    # shaded band: mean +/- 3 standard deviations\n",
    "    plt.fill_between(grid, mu + 3*std, mu - 3*std, alpha = 0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac0b8967",
   "metadata": {},
   "source": [
    "## Running VarianceNO out of the box"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "6fd587e6-4cf8-4cc4-92bd-ac09c7be790a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function datasets.PME_1D.get_mass_rhs_func.<locals>.mass_rhs_func(inputs)>"
      ]
     },
     "execution_count": 175,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_class.get_mass_rhs_func(x=x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "385681cf-a2fa-4ba9-8a9b-462fabb915ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train_reshaped = rearrange(x_train, \" nf nx nt 1 -> nf (nx nt)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "b339b6e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "f3073741-14be-4b35-b8cc-30ed22fbc4bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "id": "b202791b-cc81-4a0a-b82b-db3c5a831a77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function datasets.PME_1D.get_mass_rhs_func.<locals>.mass_rhs_func(inputs)>"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mass_rhs_func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "id": "d71ec298-4bbc-4186-a938-8675f0526890",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "id": "87e15df0-69a3-4d0e-82f0-f04540d1a1e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0101, 0.0202, 0.0303, 0.0404, 0.0505, 0.0606, 0.0707, 0.0808,\n",
       "        0.0909, 0.1010, 0.1111, 0.1212, 0.1313, 0.1414, 0.1515, 0.1616, 0.1717,\n",
       "        0.1818, 0.1919, 0.2020, 0.2121, 0.2222, 0.2323, 0.2424, 0.2525, 0.2626,\n",
       "        0.2727, 0.2828, 0.2929, 0.3030, 0.3131, 0.3232, 0.3333, 0.3434, 0.3535,\n",
       "        0.3636, 0.3737, 0.3838, 0.3939, 0.4040, 0.4141, 0.4242, 0.4343, 0.4444,\n",
       "        0.4545, 0.4646, 0.4747, 0.4848, 0.4949, 0.5051, 0.5152, 0.5253, 0.5354,\n",
       "        0.5455, 0.5556, 0.5657, 0.5758, 0.5859, 0.5960, 0.6061, 0.6162, 0.6263,\n",
       "        0.6364, 0.6465, 0.6566, 0.6667, 0.6768, 0.6869, 0.6970, 0.7071, 0.7172,\n",
       "        0.7273, 0.7374, 0.7475, 0.7576, 0.7677, 0.7778, 0.7879, 0.7980, 0.8081,\n",
       "        0.8182, 0.8283, 0.8384, 0.8485, 0.8586, 0.8687, 0.8788, 0.8889, 0.8990,\n",
       "        0.9091, 0.9192, 0.9293, 0.9394, 0.9495, 0.9596, 0.9697, 0.9798, 0.9899,\n",
       "        1.0000])"
      ]
     },
     "execution_count": 181,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "id": "2165975d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, -1, 5]"
      ]
     },
     "execution_count": 182,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tpred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "id": "1ad2c532",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0, -1,  5])"
      ]
     },
     "execution_count": 183,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(tpred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "id": "18a1c208-a4df-4f27-8dad-f2466164e855",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([160, 100, 20, 1])"
      ]
     },
     "execution_count": 184,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "id": "6ff40f4d-d8b7-45f7-8e3a-8115564540df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: Train loss=272.487983, Validation loss=-168.229724 (saved)\n",
      "Epoch 1: Train loss=-606.056613, Validation loss=-1254.716699 (saved)\n",
      "Epoch 2: Train loss=-1995.919812, Validation loss=-3431.097266 (saved)\n",
      "Epoch 3: Train loss=-5110.957812, Validation loss=-3500.480078 (saved)\n",
      "Epoch 4: Train loss=-6405.712402, Validation loss=-6985.716406 (saved)\n",
      "Epoch 5: Train loss=-7332.890625, Validation loss=-8184.097656 (saved)\n",
      "Epoch 6: Train loss=-8431.302246, Validation loss=-9364.947266 (saved)\n",
      "Epoch 7: Train loss=-9620.480859, Validation loss=-10476.521875 (saved)\n",
      "Epoch 8: Train loss=-10602.453516, Validation loss=-10862.499219 (saved)\n",
      "Epoch 9: Train loss=-8528.451587, Validation loss=-8474.004688 \n",
      "Epoch 10: Train loss=-9064.690137, Validation loss=-9203.159375 \n",
      "Epoch 11: Train loss=-9543.523340, Validation loss=-10537.481641 \n",
      "Epoch 12: Train loss=-11459.803711, Validation loss=-12819.908984 (saved)\n",
      "Epoch 13: Train loss=7138.775488, Validation loss=-3128.430957 \n",
      "Epoch 14: Train loss=-4337.447217, Validation loss=-4522.438867 \n",
      "Epoch 15: Train loss=-4100.837793, Validation loss=-3707.613672 \n",
      "Epoch 16: Train loss=-3630.627930, Validation loss=-3610.627148 \n",
      "Epoch 17: Train loss=-3686.062256, Validation loss=-3794.612695 \n",
      "Epoch 18: Train loss=-3893.757324, Validation loss=-4010.686914 \n",
      "Epoch 19: Train loss=-4107.118018, Validation loss=-4230.917578 \n",
      "Epoch 20: Train loss=-4337.470898, Validation loss=-4473.541992 \n",
      "Epoch 21: Train loss=-4585.358252, Validation loss=-4728.411328 \n",
      "Epoch 22: Train loss=-4851.073926, Validation loss=-5008.805859 \n",
      "Epoch 23: Train loss=-5144.979395, Validation loss=-5321.536328 \n",
      "Epoch 24: Train loss=-5473.653516, Validation loss=-5672.993555 \n",
      "Epoch 25: Train loss=-5847.064209, Validation loss=-6077.934766 \n",
      "Epoch 26: Train loss=-6285.072803, Validation loss=-6567.612500 \n",
      "Epoch 27: Train loss=-6831.161914, Validation loss=-7203.371484 \n",
      "Epoch 28: Train loss=-7579.189551, Validation loss=-8133.332422 \n",
      "Epoch 29: Train loss=-8743.915820, Validation loss=-9697.750000 \n",
      "Epoch 30: Train loss=-9956.167578, Validation loss=-10385.628906 \n",
      "Epoch 31: Train loss=-10993.952344, Validation loss=-11788.531641 \n",
      "Epoch 32: Train loss=-11050.431543, Validation loss=-10982.014063 \n",
      "Epoch 33: Train loss=-10747.816992, Validation loss=-11993.269531 \n",
      "Epoch 34: Train loss=-11426.482129, Validation loss=-11232.575000 \n",
      "Epoch 35: Train loss=-12027.412891, Validation loss=-12408.537500 \n",
      "Epoch 36: Train loss=-8554.151172, Validation loss=13144.581250 \n",
      "Epoch 37: Train loss=-3434.389014, Validation loss=-7308.229688 \n",
      "Epoch 38: Train loss=-7200.710156, Validation loss=-6535.810547 \n",
      "Epoch 39: Train loss=-6411.487598, Validation loss=-6503.701367 \n",
      "Epoch 40: Train loss=-6721.341016, Validation loss=-6951.868359 \n",
      "Epoch 41: Train loss=-7102.125195, Validation loss=-7332.178125 \n",
      "Epoch 42: Train loss=-7544.917578, Validation loss=-7830.041016 \n",
      "Epoch 43: Train loss=-8077.934375, Validation loss=-8394.142578 \n",
      "Epoch 44: Train loss=-8662.465332, Validation loss=-9020.814844 \n",
      "Epoch 45: Train loss=-9314.847168, Validation loss=-9713.795703 \n",
      "Epoch 46: Train loss=-10039.071484, Validation loss=-10472.363281 \n",
      "Epoch 47: Train loss=-10837.636816, Validation loss=-11335.392187 \n",
      "Epoch 48: Train loss=-11708.836230, Validation loss=-12240.361328 \n",
      "Epoch 49: Train loss=-12623.476074, Validation loss=-13071.939062 (saved)\n",
      "Epoch 50: Train loss=-13201.505371, Validation loss=-13438.547656 (saved)\n",
      "Epoch 51: Train loss=-13501.659180, Validation loss=-13124.583594 \n",
      "Epoch 52: Train loss=-13110.004785, Validation loss=-11194.341406 \n",
      "Epoch 53: Train loss=-13288.572266, Validation loss=-14223.976562 (saved)\n",
      "Epoch 54: Train loss=-14170.277734, Validation loss=-14361.285938 (saved)\n",
      "Epoch 55: Train loss=-14276.079687, Validation loss=-14362.032812 (saved)\n",
      "Epoch 56: Train loss=-14311.134570, Validation loss=-15064.932812 (saved)\n",
      "Epoch 57: Train loss=-12808.315527, Validation loss=-15231.615625 (saved)\n",
      "Epoch 58: Train loss=-13866.802051, Validation loss=-13642.684375 \n",
      "Epoch 59: Train loss=-14300.083008, Validation loss=-14766.426563 \n",
      "Epoch 60: Train loss=-14757.204883, Validation loss=-14667.460156 \n",
      "Epoch 61: Train loss=-15120.773438, Validation loss=-15467.184375 (saved)\n",
      "Epoch 62: Train loss=-15667.351953, Validation loss=-16026.932812 (saved)\n",
      "Epoch 63: Train loss=10922.998145, Validation loss=-3337.109375 \n",
      "Epoch 64: Train loss=-5310.198584, Validation loss=-8839.742188 \n",
      "Epoch 65: Train loss=-8371.100000, Validation loss=-7869.076172 \n",
      "Epoch 66: Train loss=-7886.617773, Validation loss=-7912.613281 \n",
      "Epoch 67: Train loss=-7906.813281, Validation loss=-7956.882812 \n",
      "Epoch 68: Train loss=-8052.875293, Validation loss=-8158.696875 \n",
      "Epoch 69: Train loss=-8253.777637, Validation loss=-8384.688281 \n",
      "Epoch 70: Train loss=-8498.786914, Validation loss=-8633.307812 \n",
      "Epoch 71: Train loss=-8743.191113, Validation loss=-8880.424219 \n",
      "Epoch 72: Train loss=-8989.935645, Validation loss=-9127.518750 \n",
      "Epoch 73: Train loss=-9238.274805, Validation loss=-9376.682422 \n",
      "Epoch 74: Train loss=-9488.313770, Validation loss=-9630.006250 \n",
      "Epoch 75: Train loss=-9742.651465, Validation loss=-9887.348438 \n",
      "Epoch 76: Train loss=-10001.233887, Validation loss=-10149.949609 \n",
      "Epoch 77: Train loss=-10266.716113, Validation loss=-10419.076563 \n",
      "Epoch 78: Train loss=-10538.421582, Validation loss=-10695.588672 \n",
      "Epoch 79: Train loss=-10817.170313, Validation loss=-10980.526562 \n",
      "Epoch 80: Train loss=-11104.221777, Validation loss=-11274.005469 \n",
      "Epoch 81: Train loss=-11401.874219, Validation loss=-11577.271094 \n",
      "Epoch 82: Train loss=-11708.280371, Validation loss=-11889.263281 \n",
      "Epoch 83: Train loss=-12016.233887, Validation loss=-12209.510156 \n",
      "Epoch 84: Train loss=-12334.414746, Validation loss=-12520.421484 \n",
      "Epoch 85: Train loss=-12623.495508, Validation loss=-12895.428906 \n",
      "Epoch 86: Train loss=-13019.804199, Validation loss=-13030.077344 \n",
      "Epoch 87: Train loss=-13335.845117, Validation loss=-13625.182812 \n",
      "Epoch 88: Train loss=-13709.363867, Validation loss=-13818.139844 \n",
      "Epoch 89: Train loss=-13877.669141, Validation loss=-14161.303125 \n",
      "Epoch 90: Train loss=-14389.829883, Validation loss=-14503.699219 \n",
      "Epoch 91: Train loss=-14843.312109, Validation loss=-14082.731250 \n",
      "Epoch 92: Train loss=3040.453687, Validation loss=1905.210938 \n",
      "Epoch 93: Train loss=-7377.025806, Validation loss=-9408.316406 \n",
      "Epoch 94: Train loss=-10594.128223, Validation loss=-10144.609375 \n",
      "Epoch 95: Train loss=-10332.015234, Validation loss=-10693.942969 \n",
      "Epoch 96: Train loss=-10673.447168, Validation loss=-10708.057812 \n",
      "Epoch 97: Train loss=-10888.056348, Validation loss=-11028.971484 \n",
      "Epoch 98: Train loss=-11140.408203, Validation loss=-11317.971875 \n",
      "Epoch 99: Train loss=-11448.478809, Validation loss=-11609.708984 \n",
      "Epoch 100: Train loss=-11689.992578, Validation loss=-11778.503125 \n",
      "Epoch 101: Train loss=-11851.716309, Validation loss=-11939.353125 \n",
      "Epoch 102: Train loss=-12016.497070, Validation loss=-12106.464844 \n",
      "Epoch 103: Train loss=-12185.225098, Validation loss=-12277.626953 \n",
      "Epoch 104: Train loss=-12357.399414, Validation loss=-12451.181250 \n",
      "Epoch 105: Train loss=-12532.483691, Validation loss=-12628.372656 \n",
      "Epoch 106: Train loss=-12711.674805, Validation loss=-12809.819141 \n",
      "Epoch 107: Train loss=-12893.833887, Validation loss=-12994.986719 \n",
      "Epoch 108: Train loss=-13080.617871, Validation loss=-13183.811719 \n",
      "Epoch 109: Train loss=-13272.103516, Validation loss=-13377.925781 \n",
      "Epoch 110: Train loss=-13467.733594, Validation loss=-13576.053906 \n",
      "Epoch 111: Train loss=-13667.006641, Validation loss=-13778.800000 \n",
      "Epoch 112: Train loss=-13871.270703, Validation loss=-13986.058594 \n",
      "Epoch 113: Train loss=-14080.368359, Validation loss=-14198.025781 \n",
      "Epoch 114: Train loss=-14291.757617, Validation loss=-14412.373438 \n",
      "Epoch 115: Train loss=-14512.148047, Validation loss=-14635.868750 \n",
      "Epoch 116: Train loss=-14733.604102, Validation loss=-14860.418750 \n",
      "Epoch 117: Train loss=-14960.595117, Validation loss=-15091.573437 \n",
      "Epoch 118: Train loss=-15188.315039, Validation loss=-15324.870312 \n",
      "Epoch 119: Train loss=-15427.853516, Validation loss=-15562.955469 \n",
      "Epoch 120: Train loss=-15666.910547, Validation loss=-15807.421875 \n",
      "Epoch 121: Train loss=-15910.644531, Validation loss=-16049.942188 (saved)\n",
      "Epoch 122: Train loss=-16150.961914, Validation loss=-16300.875000 (saved)\n",
      "Epoch 123: Train loss=-16404.142969, Validation loss=-16550.232031 (saved)\n",
      "Epoch 124: Train loss=-16660.355859, Validation loss=-16765.385937 (saved)\n",
      "Epoch 125: Train loss=-16895.543945, Validation loss=-17072.369531 (saved)\n",
      "Epoch 126: Train loss=-17143.720313, Validation loss=-17342.146875 (saved)\n",
      "Epoch 127: Train loss=-17396.685156, Validation loss=-17579.446094 (saved)\n",
      "Epoch 128: Train loss=-17648.413086, Validation loss=-17857.730469 (saved)\n",
      "Epoch 129: Train loss=-17848.926953, Validation loss=-17347.375000 \n",
      "Epoch 130: Train loss=-11047.843384, Validation loss=-12957.706250 \n",
      "Epoch 131: Train loss=-15823.540039, Validation loss=-16719.956250 \n",
      "Epoch 132: Train loss=-16647.573242, Validation loss=-16216.871875 \n",
      "Epoch 133: Train loss=-16842.590039, Validation loss=-16832.996875 \n",
      "Epoch 134: Train loss=-17147.360547, Validation loss=-17408.158594 \n",
      "Epoch 135: Train loss=-17545.982812, Validation loss=-17715.127344 \n",
      "Epoch 136: Train loss=-17820.684766, Validation loss=-17995.739063 (saved)\n",
      "Epoch 137: Train loss=-18112.908203, Validation loss=-18314.000000 (saved)\n",
      "Epoch 138: Train loss=-18433.591992, Validation loss=-18615.003125 (saved)\n",
      "Epoch 139: Train loss=-18744.227930, Validation loss=-18842.711719 (saved)\n",
      "Epoch 140: Train loss=1030.882910, Validation loss=-5449.527344 \n",
      "Epoch 141: Train loss=-4222.343408, Validation loss=-10522.019922 \n",
      "Epoch 142: Train loss=-13036.724902, Validation loss=-13154.012500 \n",
      "Epoch 143: Train loss=-13871.311719, Validation loss=-13697.561719 \n",
      "Epoch 144: Train loss=-14141.809375, Validation loss=-14128.412500 \n",
      "Epoch 145: Train loss=-14294.444531, Validation loss=-14405.170313 \n",
      "Epoch 146: Train loss=-14467.003906, Validation loss=-14590.973438 \n",
      "Epoch 147: Train loss=-14649.669531, Validation loss=-14755.490625 \n",
      "Epoch 148: Train loss=-14839.317188, Validation loss=-14937.854687 \n",
      "Epoch 149: Train loss=-15037.360937, Validation loss=-15143.760937 \n",
      "Epoch 150: Train loss=-15197.601758, Validation loss=-15247.473438 \n",
      "Epoch 151: Train loss=-15301.434375, Validation loss=-15352.215625 \n",
      "Epoch 152: Train loss=-15406.772461, Validation loss=-15458.615625 \n",
      "Epoch 153: Train loss=-15513.899219, Validation loss=-15566.352344 \n",
      "Epoch 154: Train loss=-15622.353320, Validation loss=-15675.807031 \n",
      "Epoch 155: Train loss=-15732.398633, Validation loss=-15786.597656 \n",
      "Epoch 156: Train loss=-15843.898633, Validation loss=-15898.878125 \n",
      "Epoch 157: Train loss=-15956.752344, Validation loss=-16012.751562 \n",
      "Epoch 158: Train loss=-16071.325391, Validation loss=-16128.125000 \n",
      "Epoch 159: Train loss=-16187.336914, Validation loss=-16245.014063 \n",
      "Epoch 160: Train loss=-16304.803320, Validation loss=-16363.080469 \n",
      "Epoch 161: Train loss=-16423.754883, Validation loss=-16482.567188 \n",
      "Epoch 162: Train loss=-16543.738867, Validation loss=-16604.834375 \n",
      "Epoch 163: Train loss=-16665.637891, Validation loss=-16725.892188 \n",
      "Epoch 164: Train loss=-16787.530078, Validation loss=-16848.261719 \n",
      "Epoch 165: Train loss=-16912.426172, Validation loss=-16977.151562 \n",
      "Epoch 166: Train loss=-17039.909180, Validation loss=-17104.929688 \n",
      "Epoch 167: Train loss=-17165.846484, Validation loss=-17233.606250 \n",
      "Epoch 168: Train loss=-17296.765039, Validation loss=-17363.991406 \n",
      "Epoch 169: Train loss=-17428.302734, Validation loss=-17494.372656 \n",
      "Epoch 170: Train loss=-17559.469727, Validation loss=-17621.400000 \n",
      "Epoch 171: Train loss=-17692.444922, Validation loss=-17755.108594 \n",
      "Epoch 172: Train loss=-17827.482812, Validation loss=-17895.064063 \n",
      "Epoch 173: Train loss=-17961.942383, Validation loss=-18036.674219 \n",
      "Epoch 174: Train loss=-18099.252539, Validation loss=-18156.259375 \n",
      "Epoch 175: Train loss=-18230.120703, Validation loss=-18314.835156 \n",
      "Epoch 176: Train loss=-18376.734570, Validation loss=-18454.175000 \n",
      "Epoch 177: Train loss=-18518.079102, Validation loss=-18579.442969 \n",
      "Epoch 178: Train loss=-18653.245508, Validation loss=-18742.688281 \n",
      "Epoch 179: Train loss=-18798.835547, Validation loss=-18888.321094 (saved)\n",
      "Epoch 180: Train loss=-18935.456055, Validation loss=-19021.228125 (saved)\n",
      "Epoch 181: Train loss=-19094.386133, Validation loss=-19174.017188 (saved)\n",
      "Epoch 182: Train loss=-19237.783984, Validation loss=-19259.141406 (saved)\n",
      "Epoch 183: Train loss=-19301.355273, Validation loss=-19267.747656 (saved)\n",
      "Epoch 184: Train loss=-19463.208008, Validation loss=-19529.611719 (saved)\n",
      "Epoch 185: Train loss=-19624.051367, Validation loss=-19557.654687 (saved)\n",
      "Epoch 186: Train loss=-19746.331641, Validation loss=-19855.512500 (saved)\n",
      "Epoch 187: Train loss=-19859.097461, Validation loss=-19905.860938 (saved)\n",
      "Epoch 188: Train loss=-19896.063281, Validation loss=-19948.231250 (saved)\n",
      "Epoch 189: Train loss=-20013.189063, Validation loss=-20215.527344 (saved)\n",
      "Epoch 190: Train loss=-19982.592383, Validation loss=-18946.873437 \n",
      "Epoch 191: Train loss=-19728.873242, Validation loss=-20259.582031 (saved)\n",
      "Epoch 192: Train loss=-19896.952930, Validation loss=-20470.478125 (saved)\n",
      "Epoch 193: Train loss=-19907.926172, Validation loss=-15293.784375 \n",
      "Epoch 194: Train loss=-16889.740625, Validation loss=-20339.960156 \n",
      "Epoch 195: Train loss=-19785.738672, Validation loss=-19590.864063 \n",
      "Epoch 196: Train loss=-19848.442969, Validation loss=-19958.854687 \n",
      "Epoch 197: Train loss=-20181.503125, Validation loss=-20568.504688 (saved)\n",
      "Epoch 198: Train loss=-20398.161328, Validation loss=-20353.121875 \n",
      "Epoch 199: Train loss=-20675.797070, Validation loss=-20720.925000 (saved)\n",
      "Finished training with best train loss: -20729.588086 and validation loss: -20720.925000\n"
     ]
    }
   ],
   "source": [
    "x_ood_test = x_ood_test.to(device)\n",
    "start = time.time()\n",
    "model.fit(train_loader, valid_loader, x_test=x_ood_test, epochs=epochs, lr=lr, step_size=step_size, gamma=gamma, tpred = torch.tensor(tpred).to(device), dataset_class = dataset_class, t=t.to(device), grid_train=grid.to(device))\n",
    "stop = time.time()\n",
    "# print(stop-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "id": "61efde5e-a31a-4ec9-87ee-17cf2402b98d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(model.state_dict(), \"./pme_e2e_var_update_crps.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "id": "fb522e5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %timeit model.fit(train_loader, valid_loader, x_test=x_ood_test, epochs=1, lr=lr, step_size=step_size, gamma=gamma, tpred = torch.tensor(tpred).to(device), dataset_class = dataset_class, t=t.to(device), grid_train=grid.to(device))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "id": "41b11e17-4840-42fe-9a5b-34d65f0a77d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nonlinear_projection import project_and_stats\n",
    "\n",
    "\n",
    "def _average_metrics(metrics, n_examples):\n",
    "    \"\"\"Turn summed metrics into per-example averages, in place, and return the dict.\n",
    "\n",
    "    Keys ending in ``by_example`` hold per-example values and are not divided;\n",
    "    tensor values are converted to Python lists/scalars with ``.tolist()``.\n",
    "    \"\"\"\n",
    "    for key in metrics.keys():\n",
    "        if not key.endswith(\"by_example\"):\n",
    "            metrics[key] /= n_examples\n",
    "        if isinstance(metrics[key], torch.Tensor):\n",
    "            metrics[key] = metrics[key].tolist()\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def test(model, test_loader, **test_params):\n",
    "    \"\"\"Evaluate ``model`` on ``test_loader`` and collect metrics and example predictions.\n",
    "\n",
    "    Args:\n",
    "        model: network whose forward pass returns a ``(mean, variance)`` pair\n",
    "            when the global ``uq`` is true.\n",
    "        test_loader: ``DataLoader`` over a ``TensorDataset`` of ``(x, y)``.\n",
    "        test_type (keyword): ``\"train\"``, ``\"id\"`` (default) or ``\"ood\"``; picks\n",
    "            ``dataset_params`` vs ``ood_dataset_params`` (only used by the\n",
    "            disabled plotting/saving lines at the end).\n",
    "\n",
    "    Returns:\n",
    "        dict with metrics averaged over the dataset, ``nMeRCI_all``/``rmsce_all``\n",
    "        (when ``uq``), ``n_params``/``n_flops`` and ``examples.*`` entries for\n",
    "        the random, worst, best and median (by MSE) examples. When ``uq`` and\n",
    "        the global ``is_probconserv`` are true, ProbConserv-corrected metrics\n",
    "        are stored under ``pc.*`` and conservation errors under\n",
    "        ``mcerr``/``pc.mcerr``.\n",
    "\n",
    "    Relies on notebook globals (``device``, ``uq``, ``is_probconserv``,\n",
    "    ``probconserv``, ``dataset_class``, ``t``, ``tpred``, ``grid``, ...).\n",
    "    \"\"\"\n",
    "    test_type = test_params.get(\"test_type\", \"id\")\n",
    "    mu = []\n",
    "    var = []\n",
    "    results = {\"loss\": 0.0}\n",
    "\n",
    "    model = model.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "\n",
    "            out = model(x)\n",
    "\n",
    "            results[\"loss\"] += model.loss_func(out, y).item()\n",
    "            utils.compute_all_metrics(out, y, results)\n",
    "\n",
    "            if uq:\n",
    "                mu.append(out[0].detach().cpu())\n",
    "                var.append(out[1].detach().cpu())\n",
    "            else:\n",
    "                mu.append(out.detach().cpu())\n",
    "\n",
    "    # Batch metrics were accumulated as sums; convert them to dataset averages.\n",
    "    _average_metrics(results, len(test_loader.dataset))\n",
    "\n",
    "    mu = torch.cat(mu, dim=0)\n",
    "    if uq:\n",
    "        var = torch.cat(var, dim=0)\n",
    "        std = torch.sqrt(var)\n",
    "    else:\n",
    "        var = None\n",
    "        std = None\n",
    "    x = test_loader.dataset.tensors[0]\n",
    "    y = test_loader.dataset.tensors[1]\n",
    "\n",
    "    if uq:\n",
    "        results[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, var, y).item()\n",
    "        results[\"rmsce_all\"] = utils.compute_rmsce(mu, var, y).item()\n",
    "\n",
    "        if is_probconserv:\n",
    "            # Apply the ProbConserv correction to the full-dataset prediction.\n",
    "            mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "            new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "                mu=mu[:, :, :, 0],\n",
    "                std=std[:, :, :, 0],\n",
    "                mass_rhs_func=mass_rhs_func,\n",
    "                t=t,\n",
    "                tpred=tpred,\n",
    "                grid_train=grid,\n",
    "                precis_g=np.inf,\n",
    "                second_deriv_alpha=None,\n",
    "            )\n",
    "            new_mu = new_mu[:, :, :, None]\n",
    "            new_std = new_std[:, :, :, None]\n",
    "            new_var = new_std**2\n",
    "\n",
    "            probconserv_results = _average_metrics(\n",
    "                utils.compute_all_metrics((new_mu, new_var), y, {}),\n",
    "                len(test_loader.dataset),\n",
    "            )\n",
    "            probconserv_results[\"nMeRCI_all\"] = utils.compute_nMeRCI(new_mu, new_var, y).item()\n",
    "            probconserv_results[\"rmsce_all\"] = utils.compute_rmsce(new_mu, new_var, y).item()\n",
    "\n",
    "            # Conservation error per example (absolute gap summed over the last axis),\n",
    "            # before and after the correction.\n",
    "            cerr = (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "            new_cerr = (probconserv.get_empirical_mass_rhs(new_mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "            results[\"cerr_by_example\"] = cerr.tolist()\n",
    "            results[\"mcerr\"] = cerr.mean().item()\n",
    "            probconserv_results[\"cerr_by_example\"] = new_cerr.tolist()\n",
    "            probconserv_results[\"mcerr\"] = new_cerr.mean().item()\n",
    "\n",
    "            for key, value in probconserv_results.items():\n",
    "                results[f\"pc.{key}\"] = value\n",
    "\n",
    "    # results[\"time\"] = utils.compute_forward_time(model, x[:batch_size].to(device), repetitions=10)\n",
    "    results[\"n_params\"] = utils.compute_n_params(model)\n",
    "    results[\"n_flops\"] = utils.compute_n_flops(model_name, Np=n_x*n_t, fno_modes=fno_modes, fno_width=fno_width, n_layers=4, n_models=n_models)\n",
    "\n",
    "    dataset_params_correct_type = dataset_params if test_type in (\"id\", \"train\") else ood_dataset_params\n",
    "\n",
    "    mse_by_example = torch.tensor(results[\"mse_by_example\"])\n",
    "    random_idx = np.random.choice(mse_by_example.shape[0])\n",
    "    _, worst_idx = mse_by_example.max(dim=0)\n",
    "    _, best_idx = mse_by_example.min(dim=0)\n",
    "    _, median_idx = mse_by_example.median(dim=0)\n",
    "\n",
    "    for example_name, example_idx in zip([\"random\", \"worst\", \"best\", \"median\"], [random_idx, worst_idx, best_idx, median_idx]):\n",
    "        if uq:\n",
    "            results[f\"examples.{example_name}\"] = (mu[example_idx].tolist(), var[example_idx].tolist(), y[example_idx].tolist(), x[example_idx].tolist())\n",
    "            if is_probconserv:\n",
    "                results[f\"pc.examples.{example_name}\"] = (new_mu[example_idx].tolist(), new_var[example_idx].tolist(), y[example_idx].tolist(), x[example_idx].tolist())\n",
    "        else:\n",
    "            results[f\"examples.{example_name}\"] = (mu[example_idx].tolist(), None, y[example_idx].tolist(), x[example_idx].tolist())\n",
    "\n",
    "        # prefix = f\"{test_type}_{example_name}_params={dataset_params_correct_type}\"\n",
    "        # plot_and_save(prefix, example_idx, x.squeeze(-1), y.squeeze(-1), mu.squeeze(-1), std.squeeze(-1) if std is not None else None)\n",
    "\n",
    "    # utils.dict_to_file({\"test_type\": test_type, \"params\": dataset_params_correct_type, \"results\": results}, \n",
    "    #                    f\"{run_folder}/results_{test_type}_params={dataset_params_correct_type}.json\")\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "id": "58fe7954-8db9-47db-ba7c-39fde51a4146",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here\n",
      "Here\n",
      "Here\n",
      "Here\n",
      "Train results\n",
      "MSE: 0.00033117315906565635\n",
      "n-MeRCI: 1.1552518606185913\n",
      "RMSCE: 0.25469762086868286\n",
      "Cerr: 0.05766499787569046\n",
      "In-domain results\n",
      "MSE: 0.00033166004344820977\n",
      "n-MeRCI: 0.9276496767997742\n",
      "RMSCE: 0.26144203543663025\n",
      "ProbConserv Results\n",
      "MSE: 0.000602881871163845\n",
      "n-MeRCI: 1.1581425666809082\n",
      "RMSCE: 0.25104549527168274\n",
      "Cerr: 0.05908995494246483\n",
      "Prob_Cerr: 3.3374286090293026e-07\n",
      "Here\n",
      "\n",
      "\n",
      "Out-of-domain results\n",
      "MSE: 0.00034174498170614243\n",
      "n-MeRCI: 1.0854239463806152\n",
      "RMSCE: 0.2542724907398224\n",
      "ProbConserv Results\n",
      "MSE: 0.000619906410574913\n",
      "n-MeRCI: 1.2926985025405884\n",
      "RMSCE: 0.2436477690935135\n",
      "Cerr: 0.05977409705519676\n",
      "Prob_Cerr: 3.174638152358966e-07\n"
     ]
    }
   ],
   "source": [
    "is_probconserv = True\n",
    "\n",
    "\n",
    "def print_probconserv_results(res):\n",
    "    \"\"\"Print the ProbConserv-corrected metrics that ``test`` stores under ``pc.*`` keys.\"\"\"\n",
    "    print(\"ProbConserv Results\")\n",
    "    print(f\"MSE: {res['pc.mse']}\")\n",
    "    print(f\"n-MeRCI: {res['pc.nMeRCI_all']}\")\n",
    "    print(f\"RMSCE: {res['pc.rmsce_all']}\")\n",
    "    print(f\"Cerr: {res['mcerr']}\")\n",
    "    print(f\"Prob_Cerr: {res['pc.mcerr']}\")\n",
    "\n",
    "\n",
    "# Evaluate each split exactly once; `id_results` is also read by a later cell.\n",
    "train_loader_no_shuffle = torch.utils.data.DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=False)\n",
    "train_results = test(model, train_loader_no_shuffle, test_type=\"train\")\n",
    "id_results = test(model, id_test_loader, test_type=\"id\")\n",
    "\n",
    "if is_train:\n",
    "    print(\"Train results\")\n",
    "    print(f\"MSE: {train_results['mse']}\")\n",
    "    print(f\"n-MeRCI: {train_results['nMeRCI_all']}\")\n",
    "    print(f\"RMSCE: {train_results['rmsce_all']}\")\n",
    "    if is_probconserv:\n",
    "        # `test` only computes conservation errors when ProbConserv is enabled.\n",
    "        print(f\"Cerr: {train_results['mcerr']}\")\n",
    "\n",
    "    print(\"In-domain results\")\n",
    "    print(f\"MSE: {id_results['mse']}\")\n",
    "    print(f\"n-MeRCI: {id_results['nMeRCI_all']}\")\n",
    "    print(f\"RMSCE: {id_results['rmsce_all']}\")\n",
    "    if is_probconserv:\n",
    "        print_probconserv_results(id_results)\n",
    "\n",
    "ood_results = test(model, ood_test_loader, test_type=\"ood\")\n",
    "\n",
    "print(\"\\n\")\n",
    "print(\"Out-of-domain results\")\n",
    "print(f\"MSE: {ood_results['mse']}\")\n",
    "print(f\"n-MeRCI: {ood_results['nMeRCI_all']}\")\n",
    "print(f\"RMSCE: {ood_results['rmsce_all']}\")\n",
    "\n",
    "if is_probconserv:\n",
    "    print_probconserv_results(ood_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "id": "00406482-eb2b-41df-b458-1d230c92ff1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _split_statistics(model, x_data, y_data, t, tpred, grid, dataset_class, apply_probconserv):\n",
    "    \"\"\"Predict on one split and score it; helper for ``compute_statistics``.\n",
    "\n",
    "    Returns ``(stats, mu, std)`` where ``mu``/``std`` are the CPU predictions\n",
    "    (rank 4, last axis indexed at 0), ProbConserv-corrected when\n",
    "    ``apply_probconserv`` is true.\n",
    "    \"\"\"\n",
    "    import torch\n",
    "    import utils\n",
    "    import probconserv\n",
    "\n",
    "    device = next(model.parameters()).device\n",
    "    with torch.no_grad():\n",
    "        out = model(x_data.to(device))\n",
    "\n",
    "    if isinstance(out, tuple):\n",
    "        mu, var = out[0].cpu(), out[1].cpu()\n",
    "        std = torch.sqrt(var)\n",
    "    else:\n",
    "        # Deterministic model: report zero predictive uncertainty.\n",
    "        mu = out.cpu()\n",
    "        std = torch.zeros_like(mu)\n",
    "        var = torch.square(std)\n",
    "\n",
    "    # Predictions live on the CPU, so score against CPU targets as well.\n",
    "    y_data = y_data.cpu()\n",
    "    mass_rhs_func = dataset_class.get_mass_rhs_func(x=x_data.cpu())\n",
    "\n",
    "    if apply_probconserv:\n",
    "        new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "            mu=mu[:, :, :, 0],\n",
    "            std=std[:, :, :, 0],\n",
    "            mass_rhs_func=mass_rhs_func,\n",
    "            t=t,\n",
    "            tpred=tpred,\n",
    "            grid_train=grid,\n",
    "            precis_g=float('inf'),\n",
    "            second_deriv_alpha=None,\n",
    "        )\n",
    "        mu = new_mu.unsqueeze(-1)\n",
    "        std = new_std.unsqueeze(-1)\n",
    "        var = torch.square(std)\n",
    "    else:\n",
    "        # Evaluate the mass right-hand side on the (t, x) prediction grid.\n",
    "        t_sliced = t[slice(*tpred)]\n",
    "        ts = repeat(t_sliced, \"nt -> nf nt\", nf=mu.shape[0])\n",
    "        xs = repeat(grid, \"nx -> nf nx\", nf=mu.shape[0])\n",
    "        mass_rhs = mass_rhs_func(meshgrid(ts, xs))\n",
    "\n",
    "    # Conservation error per example: |empirical - expected| summed over the last axis.\n",
    "    cerr = (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "    stats = utils.compute_all_metrics_avg((mu, var), y_data, {})\n",
    "    stats[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, var, y_data).item()\n",
    "    stats[\"rmsce_all\"] = utils.compute_rmsce(mu, var, y_data).item()\n",
    "    stats[\"cerr_by_example\"] = cerr.tolist()\n",
    "    stats[\"mcerr\"] = cerr.mean().item()\n",
    "    return stats, mu, std\n",
    "\n",
    "\n",
    "def compute_statistics(\n",
    "    model,\n",
    "    x_data,\n",
    "    y_data,\n",
    "    t,\n",
    "    tpred,\n",
    "    grid,\n",
    "    dataset_class,\n",
    "    apply_probconserv=False,\n",
    "    plot=False,\n",
    "    x_data_test=None,\n",
    "    y_data_test=None,\n",
    "    return_latex=False,\n",
    "    name=\"Model\"\n",
    "):\n",
    "    \"\"\"Score a model on a primary split and, optionally, a second (test) split.\n",
    "\n",
    "    Args:\n",
    "        model: network returning a ``(mean, variance)`` tuple or a mean tensor.\n",
    "        x_data, y_data: inputs/targets of the primary split.\n",
    "        t, tpred, grid: time axis, ``slice`` arguments selecting the prediction\n",
    "            window, and spatial grid.\n",
    "        dataset_class: provides ``get_mass_rhs_func`` for the conservation error.\n",
    "        apply_probconserv: correct predictions with ``probconserv.apply_constraint``.\n",
    "        plot: plot mean +/- 3 std against the truth for one primary-split example.\n",
    "        x_data_test, y_data_test: optional second split, scored the same way.\n",
    "        return_latex: also build a LaTeX table row (needs the second split).\n",
    "        name: row label for the LaTeX output.\n",
    "\n",
    "    Returns:\n",
    "        ``(stats, test_stats)``, or ``(stats, test_stats, latex_row)`` when\n",
    "        ``return_latex`` is true. ``test_stats`` and ``latex_row`` are ``None``\n",
    "        when no second split is given.\n",
    "    \"\"\"\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    stats, mu, std = _split_statistics(\n",
    "        model, x_data, y_data, t, tpred, grid, dataset_class, apply_probconserv\n",
    "    )\n",
    "\n",
    "    # --- Test dataset ---\n",
    "    test_stats = None\n",
    "    if x_data_test is not None and y_data_test is not None:\n",
    "        test_stats, _, _ = _split_statistics(\n",
    "            model, x_data_test, y_data_test, t, tpred, grid, dataset_class, apply_probconserv\n",
    "        )\n",
    "\n",
    "    # --- Optional plot ---\n",
    "    if plot:\n",
    "        t_idx = 1\n",
    "        param_idx = 0\n",
    "        plt.ylabel(f\"u(x, t={t[slice(*tpred)][t_idx]:.2f})\")\n",
    "        plt.xlabel(\"x\")\n",
    "        plt.title(f\"Predicted vs True (param = {x_data[param_idx,0,0,0].item():.2f})\")\n",
    "        mu_plot = mu[param_idx, :, t_idx, 0]\n",
    "        std_plot = std[param_idx, :, t_idx, 0]\n",
    "        y_true_plot = y_data[param_idx, :, t_idx, 0].cpu()\n",
    "        plt.plot(grid, mu_plot, '--', lw=2, label=\"μ ± 3σ\")\n",
    "        plt.fill_between(grid, mu_plot + 3*std_plot, mu_plot - 3*std_plot, alpha=0.2)\n",
    "        plt.plot(grid, y_true_plot, color=\"green\", label=\"true\")\n",
    "        plt.legend()\n",
    "        plt.show()\n",
    "\n",
    "    # --- Optional LaTeX row ---\n",
    "    latex_row = None\n",
    "    if return_latex and test_stats:\n",
    "        latex_row = (\n",
    "            f\"{name} & \"\n",
    "            f\"{stats['mse']:.2E} & {stats['nMeRCI_all']:.2E} & {stats['rmsce_all']:.2E} & {stats['mcerr']:.2E} & {stats['crps']:.2E} & \"\n",
    "            f\"{test_stats['mse']:.2E} & {test_stats['nMeRCI_all']:.2E} & {test_stats['rmsce_all']:.2E} & {test_stats['mcerr']:.2E} & {test_stats['crps']:.2E} \\\\\\\\\"\n",
    "        )\n",
    "\n",
    "    return (stats, test_stats, latex_row) if return_latex else (stats, test_stats)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "id": "a48efcc2-03a0-47c2-8af5-62cd2fd6e4e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unconstrained baseline: score the raw model on the training data and the OOD test set.\n",
    "train_stats, test_stats, latex = compute_statistics(\n",
    "    model,\n",
    "    x_data=x_train,\n",
    "    y_data=y_train,\n",
    "    t=t,\n",
    "    tpred=tpred,\n",
    "    grid=grid,\n",
    "    dataset_class=dataset_class,\n",
    "    apply_probconserv=False,\n",
    "    plot=False,\n",
    "    x_data_test=x_ood_test,\n",
    "    y_data_test=y_ood_test,\n",
    "    return_latex=True,\n",
    "    name=\"Unconstrained\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "id": "ea1e01de-5ed3-489d-ab73-3173bd8181b0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "341.74494445323944"
      ]
     },
     "execution_count": 192,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1e6 * test_stats['mse']  # OOD-test MSE, in units of 1e-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "id": "837265e6-c08a-4bf7-9200-8835499a555b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "35.894349217414856"
      ]
     },
     "execution_count": 193,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1e4 * test_stats['crps']  # OOD-test CRPS, in units of 1e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "id": "aa075212-97ba-4823-bac0-e43558b27ce2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "59.774111956357956"
      ]
     },
     "execution_count": 194,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1e3 * test_stats['mcerr']  # OOD-test mean conservation error, in units of 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "id": "ca88a17e-fc02-4347-8c26-74fa8bd4d1d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Unconstrained & 3.31E-04 & 1.16E+00 & 2.55E-01 & 5.77E-02 & 3.57E-03 & 3.42E-04 & 1.09E+00 & 2.54E-01 & 5.98E-02 & 3.59E-03 \\\\\\\\'"
      ]
     },
     "execution_count": 195,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LaTeX table row (train metrics, then OOD-test metrics) for the unconstrained model\n",
    "latex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "id": "98acb797",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.000602881871163845"
      ]
     },
     "execution_count": 196,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In-domain MSE after the ProbConserv correction, as computed by `test`\n",
    "id_results['pc.mse']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ec3729e-d8e6-454f-8e4e-6b05f44a94f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the OOD test set without building an autograd graph, so the\n",
    "# outputs can be converted/plotted directly in the cells below.\n",
    "with torch.no_grad():\n",
    "    out = model(x_ood_test.to(device))\n",
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "\n",
    "if model.probconserv:\n",
    "    # Apply the ProbConserv correction to the raw (mean, variance) prediction.\n",
    "    _mu, _var = out[0].cpu(), out[1].cpu()\n",
    "    _std = torch.sqrt(_var)\n",
    "    new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "        mu=_mu[:, :, :, 0],\n",
    "        std=_std[:, :, :, 0],\n",
    "        mass_rhs_func=mass_rhs_func,\n",
    "        t=t,\n",
    "        tpred=tpred,\n",
    "        grid_train=grid,\n",
    "        precis_g=np.inf,\n",
    "        second_deriv_alpha=None,\n",
    "    )\n",
    "    out = (new_mu.unsqueeze(-1), torch.square(new_std).unsqueeze(-1))\n",
    "\n",
    "mu, var = out[0].cpu(), out[1].cpu()\n",
    "nf, nx, nt, _ = mu.shape\n",
    "std = torch.sqrt(var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa65af1d-7356-4892-acff-81125e64347f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-norm over all per-example CRPS values on the OOD test set\n",
    "utils.compute_crps_by_example(mu, var, y_ood_test).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11e418ff-cbf6-4455-a062-676e2df0a1cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# u_proj, u_var = project_and_stats(nmu.view(nf,-1).to(device), nvar.view(nf,-1).to(device), _m, model.full_residual, max_iter=30)\n",
    "# nmu, nvar = (u_proj.view(nf,nx,nt,1), u_var.view(nf,nx,nt,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d043cc-dd31-4a21-ad06-663438462c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.norm(utils.compute_crps_by_example(mu.cpu(), nvar.cpu(), y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2ecdfcd-84ec-49f2-9d09-51549d332c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.func import vmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e96a231-5af5-4a84-af56-f3069403e0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.norm(vmap(model.full_residual)(torch.relu(mu.to(device)), _m),dim = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd9a8f9-b2d7-43c3-982b-f4e547fcaf20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE of the first example at time index 1 (mean prediction vs OOD truth)\n",
    "torch.nn.functional.mse_loss(mu[0, :, 1, :], y_ood_test[0, :, 1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "995e5a6f-63e8-4cf3-85f9-73f8d7d30009",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the predicted mean +/- 3 std against the OOD ground truth at one time slice.\n",
    "t_idx = 10  # e.g. len(t[slice(*tpred)]) // 2 for the middle of the prediction window\n",
    "t_window = t[slice(*tpred)]\n",
    "\n",
    "# range(0, 1, 5) only visits example 0; raise the stop value to plot more examples.\n",
    "for parameters_idx in range(0, 1, 5):\n",
    "    # Detach and move to CPU so matplotlib can convert the tensors.\n",
    "    mu_slice = mu[parameters_idx, :, t_idx, 0].detach().cpu()\n",
    "    std_slice = std[parameters_idx, :, t_idx, 0].detach().cpu()\n",
    "    y_slice = y_ood_test[parameters_idx, :, t_idx, 0].cpu()\n",
    "\n",
    "    plt.ylabel(\"u(x,t={t:.2f})\".format(t=t_window[t_idx]))\n",
    "    plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k=x_ood_test[parameters_idx, 0, 0, 0], dataset=dataset))\n",
    "    plt.xlabel(\"x\")\n",
    "    # Raw string: \\mu, \\pm and \\sigma are mathtext, not Python escape sequences.\n",
    "    plt.plot(grid, mu_slice, '--', lw=2, label=r\"predicted $\\mu$ and $\\pm 3\\sigma$ (probharde2e)\")\n",
    "    plt.fill_between(grid, mu_slice + 3 * std_slice, mu_slice - 3 * std_slice, alpha=0.2)\n",
    "    plt.plot(grid, y_slice, color=\"green\", label=\"true\")\n",
    "\n",
    "    # Error against the same OOD targets that are plotted.\n",
    "    print(torch.norm(y_slice - mu_slice))\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dfa6e03-c09a-406f-abf8-d41d68afa9f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the OOD mean prediction\n",
    "mu.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6376bd34-380d-43f0-b9ae-e21dfa77d686",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the OOD predictive variance\n",
    "var.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1824da2b-18e6-4b1f-80d5-4e5d7e167a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the spatial grid (sanity check)\n",
    "grid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01feae7c-7aba-49ca-9723-f00545566493",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slice selecting the prediction window within t\n",
    "slice(*tpred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a10613-e7ed-48d3-bd48-3cd1f6ace4b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time stamps inside the prediction window\n",
    "t[slice(*tpred)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eecfe6c0-2ee6-4918-b2b8-7856bf097fdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step sizes of the prediction-window time axis and of the spatial grid\n",
    "# (taken from the first interval; assumes uniform spacing).\n",
    "dt, dx = (\n",
    "    t[slice(*tpred)][1] - t[slice(*tpred)][0],\n",
    "    grid[1] - grid[0],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e3bc062-c816-49f6-bc44-103fa68cadbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conservation_residual_all_times(u_flat, m):\n",
    "    u = u_flat.view(nx, nt)\n",
    "\n",
    "    # Initial mass\n",
    "    mass_0 = torch.sum(u[:, 0]) * dx  # shape: scalar\n",
    "\n",
    "    # Compute u^m * ∂u/∂x at left and right boundary for all time steps\n",
    "    um = u ** m\n",
    "    ux_left = (u[1, :] - u[0, :]) / dx         # shape: (nt,)\n",
    "    ux_right = (u[-1, :] - u[-2, :]) / dx      # shape: (nt,)\n",
    "    \n",
    "    flux_left = um[0, :] * ux_left             # u^m * u_x at x = 0\n",
    "    flux_right = um[-1, :] * ux_right          # u^m * u_x at x = 1\n",
    "\n",
    "    net_flux = torch.cumsum(flux_right - flux_left, dim=0) * dt  # ∫₀^t (flux_right - flux_left)\n",
    "\n",
    "    # Total mass at each time\n",
    "    mass_t = torch.sum(u, dim=0) * dx          # shape: (nt,)\n",
    "\n",
    "    # Conservation residual\n",
    "    residue = mass_t - mass_0 + net_flux       # sign flipped because outflow reduces mass\n",
    "\n",
    "    return residue  # shape: (nt,)\n",
    "\n",
    "# Combine\n",
    "def full_residual(u_flat):\n",
    "    return torch.cat([conservation_residual_all_times(u_flat)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7567e79a-911c-4214-8d55-f2e05643e4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.func import vmap, jacrev\n",
    "\n",
    "# --- Projection code ---\n",
    "def fast_project_batched(xi_batch, h_func, max_iter=30):\n",
    "    B, n = xi_batch.shape\n",
    "\n",
    "    def newton_step(u, xi):\n",
    "        h_val = h_func(u)\n",
    "        if h_val.ndim == 1:\n",
    "            h_val = h_val.unsqueeze(-1)\n",
    "        J = jacrev(h_func)(u)\n",
    "        if J.ndim == 1:\n",
    "            J = J.unsqueeze(0)\n",
    "        delta = (xi - u).unsqueeze(-1)\n",
    "        JJt = J @ J.transpose(-2, -1)\n",
    "        rhs = J @ delta + h_val\n",
    "        lambda_ = torch.linalg.solve(JJt, rhs)\n",
    "        du = delta - J.transpose(-2, -1) @ lambda_\n",
    "        return u + du.squeeze(-1)\n",
    "\n",
    "    def loop(xi):\n",
    "        u = xi.clone()\n",
    "        for _ in range(max_iter):\n",
    "            u = newton_step(u, xi)\n",
    "        return u\n",
    "\n",
    "    return vmap(loop)(xi_batch)\n",
    "\n",
    "def fast_project_weighted(xi_batch, sigma_batch, h_func, max_iter=30):\n",
    "    \"\"\"\n",
    "    Oblique (inverse-variance weighted) projection onto {u : h(u) = 0}.\n",
    "\n",
    "    Solve: argmin_u (u - xi)^T Sigma^{-1} (u - xi)  s.t. h(u) = 0,\n",
    "    with Sigma = diag(sigma), by repeated linearization of h. With\n",
    "    delta = xi - u and J = dh/du at the current iterate, each step is\n",
    "        (J Sigma J^T) lambda = J delta + h(u)\n",
    "        du = delta - Sigma J^T lambda\n",
    "    so the linearized constraint h(u) + J du = 0 holds after the step, and a\n",
    "    fixed point satisfies the KKT conditions\n",
    "        Sigma^{-1} (u - xi) + J^T lambda = 0,   h(u) = 0.\n",
    "    Coordinates with small variance therefore move the least.\n",
    "\n",
    "    Args:\n",
    "        xi_batch: (B, n)\n",
    "        sigma_batch: (B, n) non-negative variances (diagonal of Sigma); a zero\n",
    "            freezes that coordinate, and if every coordinate a constraint\n",
    "            touches is zero the linear solve is singular\n",
    "        h_func: function u → ℝ^m (or scalar)\n",
    "        max_iter: number of Newton steps\n",
    "    Returns:\n",
    "        u_proj: (B, n)\n",
    "    \"\"\"\n",
    "    B, n = xi_batch.shape\n",
    "\n",
    "    def newton_step(u, xi, sigma):\n",
    "        h_val = h_func(u)\n",
    "        if h_val.ndim == 1:\n",
    "            h_val = h_val.unsqueeze(-1)\n",
    "        J = jacrev(h_func)(u)\n",
    "        if J.ndim == 1:\n",
    "            J = J.unsqueeze(0)\n",
    "\n",
    "        delta = (xi - u).unsqueeze(-1)                 # (n, 1)\n",
    "\n",
    "        # J Sigma without forming diag(sigma): scale column j of J by sigma_j.\n",
    "        # The Sigma^{-1}-norm objective requires weighting the step by Sigma\n",
    "        # (not Sigma^{-1}); the latter would minimize the wrong norm.\n",
    "        J_sigma = J * sigma.unsqueeze(0)               # (m, n)\n",
    "        JSJt = J_sigma @ J.transpose(-2, -1)           # (m, m)\n",
    "        rhs = J @ delta + h_val                        # (m, 1)\n",
    "\n",
    "        lambda_ = torch.linalg.solve(JSJt, rhs)        # (m, 1)\n",
    "        du = delta - J_sigma.transpose(-2, -1) @ lambda_   # (n, 1) = delta - Sigma J^T lambda\n",
    "\n",
    "        return u + du.squeeze(-1)\n",
    "\n",
    "    def loop(xi, sigma):\n",
    "        u = xi.clone()\n",
    "        for _ in range(max_iter):\n",
    "            u = newton_step(u, xi, sigma)\n",
    "        return u\n",
    "\n",
    "    return vmap(loop)(xi_batch, sigma_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab81cad-a983-472d-a561-627d397b2938",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the predicted-mean dimensions; the last axis is assumed to be a\n",
    "# singleton channel (u_proj is later viewed as (nf, nx, nt, 1)).\n",
    "nf, nx, nt, _ = mu.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8bdf1f-e427-4098-87bb-11017ac65088",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One flattened row of length nx * nt per sample, as expected by the projection routines\n",
    "mu_flat, var_flat, m_flat = (\n",
    "    field.view(nf, -1).to(device) for field in (mu, var, x_ood_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "319cd929-a2b9-4410-ab6e-eae65b91ee5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_ood_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9cbb76-f1a6-4bd0-8a1f-1808abb5539e",
   "metadata": {},
   "outputs": [],
   "source": [
    "mu_flat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27fa38f0-9d5f-49b6-bcc8-0f455cd0ff99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def conservation_residual_all_times(u_flat, m_flat):\n",
    "#     u = u_flat.view(nx, nt)\n",
    "#     m = m_flat.view(nx, nt)\n",
    "\n",
    "#     # Initial mass\n",
    "#     mass_0 = torch.sum(u[:, 0]) * dx  # shape: scalar\n",
    "\n",
    "#     # Compute u^m * ∂u/∂x at left and right boundary for all time steps\n",
    "#     um = u ** m\n",
    "#     ux_left = (u[1, :] - u[0, :]) / dx         # shape: (nt,)\n",
    "#     ux_right = (u[-1, :] - u[-2, :]) / dx      # shape: (nt,)\n",
    "    \n",
    "#     flux_left = um[0, :] * ux_left             # u^m * u_x at x = 0\n",
    "#     flux_right = um[-1, :] * ux_right          # u^m * u_x at x = 1\n",
    "\n",
    "#     net_flux = torch.cumsum(flux_right - flux_left, dim=0) * dt  # ∫₀^t (flux_right - flux_left)\n",
    "\n",
    "#     # Total mass at each time\n",
    "#     mass_t = torch.sum(u, dim=0) * dx          # shape: (nt,)\n",
    "\n",
    "#     # Conservation residual\n",
    "#     residue = mass_t - mass_0 + net_flux       # sign flipped because outflow reduces mass\n",
    "\n",
    "#     return residue[1:]  # shape: (nt,)\n",
    "\n",
    "def conservation_residual_all_times(u_flat, m_flat):\n",
    "    \"\"\"\n",
    "    Discrete mass-balance residual, one entry per time step after t_0.\n",
    "\n",
    "    With q = u^m * u_x the diffusive flux (presumably the PME form\n",
    "    u_t = (u^m u_x)_x -- TODO confirm against the data generator), mass changes\n",
    "    only through the boundaries:  d/dt ∫ u dx = q(x_N) - q(x_0).\n",
    "    Discretization: trapezoidal rule in x, one-sided first-order differences\n",
    "    for u_x at the boundaries, left Riemann sum in t.\n",
    "\n",
    "    Args:\n",
    "        u_flat: flattened field, viewed as (nx, nt)\n",
    "        m_flat: flattened PME exponent, viewed as (nx, nt)\n",
    "    Returns:\n",
    "        (nt - 1,) tensor. The t_0 entry is dropped because it is identically\n",
    "        zero; its all-zero Jacobian row would make J J^T singular.\n",
    "\n",
    "    Uses the notebook globals nx, nt, dx, dt.\n",
    "    \"\"\"\n",
    "    u = u_flat.view(nx, nt)\n",
    "    m = m_flat.view(nx, nt)\n",
    "\n",
    "    def trapz_space(v):\n",
    "        # trapezoidal weights along x (dim 0): half weight on both endpoints\n",
    "        weight = torch.ones_like(v)\n",
    "        weight[0, ...] *= 0.5\n",
    "        weight[-1, ...] *= 0.5\n",
    "        return torch.sum(weight * v, dim=0) * dx\n",
    "\n",
    "    # Initial mass (scalar)\n",
    "    mass_0 = trapz_space(u[:, 0])\n",
    "\n",
    "    # Boundary derivatives: forward difference at x_0, backward difference at x_N\n",
    "    um = u ** m\n",
    "    ux_left = (u[1, :] - u[0, :]) / dx     # (nt,)\n",
    "    ux_right = (u[-1, :] - u[-2, :]) / dx  # (nt,)\n",
    "\n",
    "    # Net boundary inflow rate q(x_N) - q(x_0), shape (nt,)\n",
    "    inflow_rate = um[-1, :] * ux_right - um[0, :] * ux_left\n",
    "\n",
    "    # Left Riemann sum: cumulative inflow up to each time step (0 at t_0)\n",
    "    inflow_increments = torch.cat([\n",
    "        torch.zeros(1, device=u.device, dtype=inflow_rate.dtype),\n",
    "        inflow_rate[:-1] * dt,\n",
    "    ], dim=0)\n",
    "    cumulative_inflow = torch.cumsum(inflow_increments, dim=0)\n",
    "\n",
    "    # Mass at each time step, shape (nt,)\n",
    "    mass_t = trapz_space(u)\n",
    "\n",
    "    # Mass balance: M(t) - M(t_0) - ∫_0^t (q(x_N) - q(x_0)) dt ≈ 0\n",
    "    residue = mass_t - mass_0 - cumulative_inflow\n",
    "\n",
    "    return residue[1:]  # skip t=0\n",
    "\n",
    "\n",
    "def ic_residual(u_flat, u0):\n",
    "    \"\"\"\n",
    "    Initial-condition residual: interior values of u at the first time step,\n",
    "    which the projection drives to zero.\n",
    "\n",
    "    The boundary nodes u[0, 0] and u[-1, 0] are left out because\n",
    "    bc_residual_dirichlet already constrains u[0, :] and u[-1, :]; repeating\n",
    "    those rows would make the constraint Jacobian rank-deficient.\n",
    "\n",
    "    Args:\n",
    "        u_flat: flattened field, viewed as (nx, nt)\n",
    "        u0: unused (full_residual passes m_flat here so that every residual\n",
    "            shares the (u_flat, m_flat) calling convention)\n",
    "    Returns:\n",
    "        Tensor of shape (nx - 2,)\n",
    "    \"\"\"\n",
    "    u = u_flat.view(nx, nt)           # shape: (nx, nt)\n",
    "    return u[1:-1, 0]                 # shape: (nx - 2,)\n",
    "\n",
    "\n",
    "# Prediction time grid, moved to the compute device once and captured by the BC residual\n",
    "t_grid = t[slice(*tpred)].clone().to(device)\n",
    "\n",
    "\n",
    "def bc_residual_dirichlet(u_flat, m_flat):\n",
    "    \"\"\"\n",
    "    Dirichlet boundary residual for one sample, shape (2 * nt,).\n",
    "\n",
    "    First nt entries:  u(x_0, t) - (m t)^(1/m)\n",
    "    Last nt entries:   u(x_N, t)            (target 0)\n",
    "    The exponent m is taken from m_flat[0], i.e. assumed constant over the\n",
    "    grid for a given sample (TODO confirm).\n",
    "    \"\"\"\n",
    "    exponent = m_flat[0]\n",
    "    field = u_flat.view(nx, nt)\n",
    "    left_target = (exponent * t_grid) ** (1.0 / exponent)  # (nt,)\n",
    "    return torch.cat([field[0, :] - left_target, field[-1, :]], dim=0)\n",
    "\n",
    "# Combine\n",
    "def full_residual(u_flat, m_flat):\n",
    "    \"\"\"\n",
    "    All hard constraints h(u; m) for one sample, stacked into one vector of\n",
    "    length (nx - 2) + (nt - 1) + 2 * nt.\n",
    "    \"\"\"\n",
    "    return torch.cat([ic_residual(u_flat, m_flat),                     # interior IC, (nx - 2,)\n",
    "                      conservation_residual_all_times(u_flat, m_flat),  # mass balance, (nt - 1,)\n",
    "                      bc_residual_dirichlet(u_flat, m_flat),            # Dirichlet BCs, (2 * nt,)\n",
    "                     ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f9489c5-b02d-4719-b675-af1c38df18fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fast_project_weighted_with_m(xi_batch, sigma_batch, m_batch, h_func, max_iter=30, weighted=False):\n",
    "    \"\"\"\n",
    "    Batched Gauss-Newton projection onto {u >= 0 : h_func(u, m) = 0}.\n",
    "\n",
    "    weighted=False (default): Euclidean / orthogonal projection; sigma_batch is\n",
    "        accepted but ignored (kept for signature compatibility).\n",
    "    weighted=True: oblique projection\n",
    "        argmin_u (u - xi)^T Sigma^{-1} (u - xi)  s.t. h(u, m) = 0,\n",
    "        Sigma = diag(sigma) (clamped below at 1e-6), via the step\n",
    "        (J Sigma J^T) lambda = J delta + h,  du = delta - Sigma J^T lambda.\n",
    "    Each iterate is clamped to be non-negative.\n",
    "\n",
    "    NOTE: the next cell redefines fast_project_weighted_with_m (whitened\n",
    "    version); when cells run in order, that definition shadows this one.\n",
    "\n",
    "    Args:\n",
    "        xi_batch:    (B, p) unconstrained points\n",
    "        sigma_batch: (B, p) variances (used only when weighted=True)\n",
    "        m_batch:     (B, ...) per-sample parameters passed to h_func\n",
    "        h_func:      callable h(u, m) -> constraint residual\n",
    "        max_iter:    number of Newton steps\n",
    "        weighted:    select oblique (True) or orthogonal (False) projection\n",
    "    Returns:\n",
    "        (B, p) projected points\n",
    "    \"\"\"\n",
    "    def newton_step(u, xi, sigma, m):\n",
    "        # Diagonal of the metric used for the step: Sigma, or the identity\n",
    "        weight = sigma.clamp(min=1e-6) if weighted else torch.ones_like(sigma)\n",
    "        h_val = h_func(u, m)\n",
    "        if h_val.ndim == 1:\n",
    "            h_val = h_val.unsqueeze(-1)\n",
    "        J = jacrev(lambda u_: h_func(u_, m))(u)  # (k, p)\n",
    "        if J.ndim == 1:\n",
    "            J = J.unsqueeze(0)\n",
    "\n",
    "        delta = (xi - u).unsqueeze(-1)           # (p, 1)\n",
    "\n",
    "        JW = J * weight.unsqueeze(0)             # J Sigma, (k, p)\n",
    "        JWJt = JW @ J.transpose(-2, -1)          # (k, k)\n",
    "        rhs = J @ delta + h_val                  # (k, 1)\n",
    "\n",
    "        lambda_ = torch.linalg.solve(JWJt, rhs)  # (k, 1)\n",
    "        du = delta - JW.transpose(-2, -1) @ lambda_   # delta - Sigma J^T lambda\n",
    "\n",
    "        return torch.clamp(u + du.squeeze(-1), min=0.0)\n",
    "\n",
    "    def loop(xi, sigma, m):\n",
    "        u = xi.clone()\n",
    "        for _ in range(max_iter):\n",
    "            u = newton_step(u, xi, sigma, m)\n",
    "        return u\n",
    "\n",
    "    return vmap(loop)(xi_batch, sigma_batch, m_batch)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "681dc3a6-49ef-40ad-b78c-a359b0aa81af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fast_project_weighted_with_m(xi_batch, sigma_batch, m_batch, h_func, max_iter=30):\n",
    "    \"\"\"\n",
    "    Oblique projection of each sample onto {u >= 0 : h_func(u, m) = 0}.\n",
    "\n",
    "    Solves  argmin_u (u - xi)^T Sigma^{-1} (u - xi)  s.t. h(u, m) = 0\n",
    "    with Sigma = diag(sigma) by whitening: in z = Sigma^{-1/2} u coordinates\n",
    "    the weighted norm becomes Euclidean, so each Gauss-Newton step is an\n",
    "    orthogonal linearized projection. Iterates are clamped to z >= 0, which\n",
    "    is equivalent to u >= 0 because Sigma^{1/2} > 0.\n",
    "\n",
    "    Args:\n",
    "        xi_batch:    (B, p) unconstrained points\n",
    "        sigma_batch: (B, p) variances (clamped below at 1e-6)\n",
    "        m_batch:     (B, ...) per-sample parameters passed to h_func\n",
    "        h_func:      callable h(u, m) -> constraint residual\n",
    "        max_iter:    number of Newton steps\n",
    "    Returns:\n",
    "        (B, p) projected points\n",
    "    \"\"\"\n",
    "    def project_single(xi, sigma, m):\n",
    "        scale = sigma.clamp(min=1e-6).sqrt()   # Sigma^{1/2}\n",
    "        target = xi * (1.0 / scale)            # whitened xi\n",
    "        z = target.clone()                     # whitened iterate\n",
    "\n",
    "        for _ in range(max_iter):\n",
    "            u = z * scale\n",
    "            residual = h_func(u, m)\n",
    "            if residual.ndim == 1:\n",
    "                residual = residual.unsqueeze(-1)\n",
    "            # chain rule: dh/dz = (dh/du) diag(Sigma^{1/2})\n",
    "            jac = jacrev(lambda v: h_func(v, m))(u) * scale.unsqueeze(0)\n",
    "            jac_t = jac.transpose(-2, -1)\n",
    "            gap = (target - z).unsqueeze(-1)\n",
    "            multipliers = torch.linalg.solve(jac @ jac_t, jac @ gap + residual)\n",
    "            z = (z + (gap - jac_t @ multipliers).squeeze(-1)).clamp(min=0.0)\n",
    "\n",
    "        return (z * scale).clamp(min=0.0)\n",
    "\n",
    "    return vmap(project_single)(xi_batch, sigma_batch, m_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e883c9c9-bb46-4702-baf8-26e3b24a989c",
   "metadata": {},
   "outputs": [],
   "source": [
    "var_flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfe6a62-07b8-4f10-9815-58ee606220b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9decf451-14e0-4e6f-8f4f-bdd4fc1821e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f5fde6-fd55-4c0b-8d1e-d1883f63600d",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_residual(mu_flat[0,:], m_flat[0,:]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc460f0d-5920-4f3e-886d-657a7d747b65",
   "metadata": {},
   "outputs": [],
   "source": [
    "m_flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d5d420-215b-4320-add6-4d58bf746878",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3be0869f-f0db-41df-94f4-814aaf315d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the ReLU-clipped (non-negative) mean using the most recent\n",
    "# definition of fast_project_weighted_with_m\n",
    "u_proj = fast_project_weighted_with_m(\n",
    "    xi_batch=torch.relu(mu_flat),\n",
    "    sigma_batch=var_flat,\n",
    "    m_batch=m_flat,\n",
    "    h_func=full_residual,\n",
    ")\n",
    "u_proj_reshaped = u_proj.view(nf, nx, nt, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ce1384-f11f-4d5a-8f9d-b96339ac6788",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_proj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb3a9c7e-06bc-49b6-898c-21ea991483c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_proj[torch.isnan(u_proj)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "957e7c15-a613-4e67-a9cb-1856e51e8dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_proj_reshaped[0,:,0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ccd11e3-69f6-435f-8a39-0dcc9acb9003",
   "metadata": {},
   "outputs": [],
   "source": [
    "mu[0,:,0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9e3bb2-41d7-49f0-916c-246826661660",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_residual(u_proj_reshaped[0,...].flatten().to(device), m_flat[0,...])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7908c2-1b72-40eb-922f-35bfab9a071e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ProbConserv baseline: constrain the same prediction with probconserv.apply_constraint.\n",
    "# Standard deviation from the model's predicted variance.\n",
    "std = torch.sqrt(var)\n",
    "\n",
    "# NOTE(review): `out` and `y` are not used in the visible part of the notebook;\n",
    "# the extra model forward pass may be removable -- confirm before deleting.\n",
    "out = model(x_ood_test.to(device))\n",
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "\n",
    "# Drop the trailing channel axis (index 0) before calling apply_constraint.\n",
    "# precis_g=np.inf presumably enforces the conservation constraint exactly\n",
    "# (noise-free constraint) -- confirm against probconserv.apply_constraint.\n",
    "new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "    mu=mu[:, :, :, 0], \n",
    "    std=std[:, :, :, 0], \n",
    "    mass_rhs_func=mass_rhs_func, \n",
    "    t=t, \n",
    "    tpred=tpred, \n",
    "    grid_train=grid, \n",
    "    precis_g=np.inf,\n",
    "    second_deriv_alpha=None,\n",
    ")\n",
    "# Restore the trailing channel axis to match mu / u_proj_reshaped.\n",
    "new_mu = new_mu[:, :, :, None]\n",
    "new_std = new_std[:, :, :, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cedca976-7263-4789-aeaf-1850bca3d083",
   "metadata": {},
   "outputs": [],
   "source": [
    "ic_residual(u_proj[0,...].to(device), m_flat[0,...])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac4ef7b7-025a-439c-b21a-8fbd9f037f80",
   "metadata": {},
   "outputs": [],
   "source": [
    "ic_residual(torch.abs(new_mu[0,...]).flatten().to(device), m_flat[0,...])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66fe7479-1fa6-4b3a-9603-43c98c62b13e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.MSELoss()(mu,y_ood_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "840f0465-2661-4b3e-81b6-04a03a432fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.MSELoss()(u_proj_reshaped.cpu(), y_ood_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae9a844-eb88-48d7-a4f0-411d7cb8fb9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.MSELoss()(new_mu.cpu(), y_ood_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c8eb7f-8a71-4fb9-b007-12fb7d0aa798",
   "metadata": {},
   "outputs": [],
   "source": [
    "vmap(ic_residual)(u_proj,m_flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1197d13e-3821-4d20-a7fc-9c4a313ce623",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_mu[0,:,0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31182cc9-f3fa-4b1d-aa63-5f91b505463e",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train[0,:,0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53fd25a7-c6b3-4d3e-943c-39e0e2c5ad11",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_proj_reshaped[0,:,0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e51ff78-8e4f-40e9-8457-27dcacdafbaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.func import vmap, jacrev\n",
    "\n",
    "class ProjectWeightedImplicit(torch.autograd.Function):\n",
    "    \"\"\"\n",
    "    Oblique projection onto {u : h(u, m) = 0} with an implicit-differentiation backward.\n",
    "\n",
    "    Forward: batched Gauss-Newton solve of\n",
    "        argmin_u (u - xi)^T Σ⁻¹ (u - xi)  s.t. h(u, m) = 0,   Σ = diag(sigma),\n",
    "    in whitened coordinates ũ = Σ^{-1/2} u, with iterates clamped to u >= 0.\n",
    "\n",
    "    Backward: differentiates the KKT system\n",
    "        Σ⁻¹ (u - xi) + Jᵀ λ = 0,   h(u) = 0\n",
    "    with respect to xi, using only the Jacobian J = ∂h/∂u at the solution.\n",
    "    The curvature term (Hessian of λᵀh) and the u >= 0 clamp are ignored, so\n",
    "    the gradient is exact only when h is locally linear and no clamp is active.\n",
    "    Only xi_batch receives a gradient.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, xi_batch, sigma_batch, m_batch, h_func, max_iter=30):\n",
    "        \"\"\"\n",
    "        Batched Newton projection.\n",
    "\n",
    "        Args:\n",
    "            xi_batch:    Tensor of shape (B, p), unconstrained points\n",
    "            sigma_batch: Tensor of shape (B, p), variances (clamped below at 1e-6)\n",
    "            m_batch:     Tensor of extra parameters per sample (e.g. (B, …))\n",
    "            h_func:      callable h(u, m) → Tensor of shape (k,) or scalar\n",
    "            max_iter:    number of Newton steps\n",
    "\n",
    "        Returns:\n",
    "            u_proj: Tensor of shape (B, p), approximately satisfying h(u_proj, m) = 0\n",
    "        \"\"\"\n",
    "        B, p = xi_batch.shape\n",
    "        eps = 1e-6  # lower bound on sigma; backward must use the same value\n",
    "\n",
    "        def project_one(xi, sigma, m):\n",
    "            sigma = sigma.clamp(min=eps)\n",
    "            sqrt_sigma = sigma.sqrt()\n",
    "            inv_sqrt = 1.0 / sqrt_sigma\n",
    "\n",
    "            # transform to tilde-space, where the weighted norm is Euclidean\n",
    "            xi_tilde = xi * inv_sqrt\n",
    "            u_tilde = xi_tilde.clone()\n",
    "\n",
    "            def newton_step(u_tilde):\n",
    "                # back to u-space\n",
    "                u = u_tilde * sqrt_sigma\n",
    "                h = h_func(u, m)\n",
    "                # make h a vector of shape (k,)\n",
    "                if h.ndim == 0:\n",
    "                    h = h.unsqueeze(0)\n",
    "\n",
    "                # Jacobian ∂h/∂u at u: shape (k, p)\n",
    "                J_u = jacrev(lambda u_: h_func(u_, m))(u)\n",
    "                if J_u.ndim == 1:\n",
    "                    J_u = J_u.unsqueeze(0)\n",
    "                # chain rule for tilde: dh/dũ = J_u * diag(sqrt_sigma)\n",
    "                J = J_u * sqrt_sigma.unsqueeze(0)\n",
    "\n",
    "                # Newton update in tilde-space\n",
    "                δ = (xi_tilde - u_tilde).unsqueeze(-1)        # (p,1)\n",
    "                JJt = J @ J.transpose(-2, -1)                  # (k,k)\n",
    "                rhs = J @ δ + h.unsqueeze(-1)                  # (k,1)\n",
    "                # small ridge keeps the solve well-posed if J Jᵀ is near-singular\n",
    "                I = torch.eye(JJt.shape[-1], device=JJt.device, dtype=JJt.dtype)\n",
    "                λ = torch.linalg.solve(JJt + 1e-6 * I, rhs)   # (k,1)\n",
    "\n",
    "                du = δ - J.transpose(-2, -1) @ λ               # (p,1)\n",
    "                return torch.clamp(u_tilde + du.squeeze(-1), min=0.0)\n",
    "\n",
    "            for _ in range(max_iter):\n",
    "                u_tilde = newton_step(u_tilde)\n",
    "\n",
    "            # back to original space and clamp\n",
    "            return torch.clamp(u_tilde * sqrt_sigma, min=0.0)\n",
    "\n",
    "        u_proj = vmap(project_one)(xi_batch, sigma_batch, m_batch)\n",
    "\n",
    "        ctx.save_for_backward(xi_batch, sigma_batch, m_batch, u_proj)\n",
    "        ctx.h_func = h_func\n",
    "        return u_proj\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_u):\n",
    "        \"\"\"\n",
    "        Vector-Jacobian product through the projection via the KKT system.\n",
    "\n",
    "        For each sample, with A = [[Σ⁻¹, Jᵀ], [J, 0]] and J = ∂h/∂u at u_proj,\n",
    "        solve Aᵀ w = [grad_u; 0] and return ∂L/∂xi = Σ⁻¹ w_u.\n",
    "        sigma is clamped with the same eps as in forward, so Σ⁻¹ stays finite\n",
    "        for zero-variance entries and matches the metric forward used.\n",
    "        \"\"\"\n",
    "        xi_batch, sigma_batch, m_batch, u_proj = ctx.saved_tensors\n",
    "        h_func = ctx.h_func\n",
    "        eps = 1e-6  # must match the sigma clamp in forward\n",
    "\n",
    "        B, p = xi_batch.shape\n",
    "        grad_xi_list = []\n",
    "\n",
    "        for i in range(B):\n",
    "            sigma = sigma_batch[i].clamp(min=eps)  # (p,)\n",
    "            u = u_proj[i]                          # (p,)\n",
    "            m = m_batch[i]\n",
    "            g = grad_u[i]                          # (p,)\n",
    "\n",
    "            # J_u = ∂h/∂u at the solution, shape (k, p)\n",
    "            J_u = jacrev(lambda u_: h_func(u_, m))(u)\n",
    "            if J_u.ndim == 1:\n",
    "                J_u = J_u.unsqueeze(0)\n",
    "            k = J_u.shape[0]\n",
    "\n",
    "            # KKT matrix A = [[Σ⁻¹, Jᵀ], [J, 0]], shape (p+k, p+k)\n",
    "            σ_inv = 1.0 / sigma\n",
    "            top = torch.cat([torch.diag(σ_inv), J_u.transpose(0, 1)], dim=1)\n",
    "            bottom = torch.cat([J_u, torch.zeros(k, k, device=u.device, dtype=J_u.dtype)], dim=1)\n",
    "            A = torch.cat([top, bottom], dim=0)\n",
    "\n",
    "            # solve Aᵀ w = [g; 0]\n",
    "            rhs = torch.cat([g, torch.zeros(k, device=g.device, dtype=g.dtype)], dim=0)\n",
    "            w = torch.linalg.solve(A.transpose(0, 1), rhs)  # (p+k,)\n",
    "\n",
    "            # gradient w.r.t. xi: Σ⁻¹ w_u\n",
    "            grad_xi_list.append(σ_inv * w[:p])\n",
    "\n",
    "        grad_xi_batch = torch.stack(grad_xi_list, dim=0)\n",
    "\n",
    "        # one gradient per forward argument: xi, sigma, m, h_func, max_iter\n",
    "        return grad_xi_batch, None, None, None, None\n",
    "\n",
    "\n",
    "# ------------------------------------------------------------------------------\n",
    "# Usage example:\n",
    "\n",
    "# Suppose:\n",
    "#   xi_batch   = torch.randn(B, p, requires_grad=True)\n",
    "#   sigma_batch= torch.rand(B, p)\n",
    "#   m_batch    = some tensor of shape (B, …)\n",
    "#   def full_residual(u, m):  return ...  # returns a vector (k,) constraint\n",
    "\n",
    "# Then project + differentiate:\n",
    "# u_proj = ProjectWeightedImplicit.apply(xi_batch, sigma_batch, m_batch, full_residual, 30)\n",
    "# loss = some_loss(u_proj)\n",
    "# loss.backward()\n",
    "# ------------------------------------------------------------------------------\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1fc975-ec2a-42bd-b41d-23cb58640675",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_proj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7406d28a-ad26-4a4c-a2b9-6e1f27a7c48f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "new_u_proj = ProjectWeightedImplicit.apply(torch.relu(mu_flat), var_flat, m_flat, full_residual, 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a26ec383-d9da-43b4-bee8-f6253b0c39e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(u_proj - new_u_proj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "170963ba-2e5c-4ab8-be48-3eb9a5156e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.func import vmap, jacrev\n",
    "\n",
    "def diag_JSigmaJT(sigma: torch.Tensor, J_u: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Diagonal of the constraint-projected covariance for one sample.\n",
    "\n",
    "    Despite the name, this does NOT return diag(J Σ Jᵀ) (that has length k).\n",
    "    It returns the length-p diagonal of\n",
    "        Σ_proj = Σ - Σ Jᵀ (J Σ Jᵀ + eps I)⁻¹ J Σ,    Σ = diag(sigma),\n",
    "    i.e. the covariance of N(·, Σ) conditioned on the linearized constraint\n",
    "    J u = const.\n",
    "\n",
    "    Args:\n",
    "      sigma: Tensor (p,), variances forming Σ = diag(sigma)\n",
    "      J_u:   Tensor (k, p) = ∂h/∂u at u*\n",
    "      eps:   ridge added to J Σ Jᵀ before solving\n",
    "    Returns:\n",
    "      Tensor (p,) of projected variances, clamped at 0 so that round-off\n",
    "      cannot produce negative values (which would become NaN under sqrt).\n",
    "    \"\"\"\n",
    "    # S = J Σ Jᵀ (k×k); scaling the columns of J by sigma avoids forming diag(sigma)\n",
    "    S = (J_u * sigma.unsqueeze(0)) @ J_u.T\n",
    "    S = S + eps * torch.eye(S.shape[0], device=S.device, dtype=S.dtype)  # regularize\n",
    "\n",
    "    # T = S⁻¹ J (k×p) via a linear solve, more stable than inv(S) @ J\n",
    "    T = torch.linalg.solve(S, J_u)\n",
    "    # quad_i = J[:, i]ᵀ S⁻¹ J[:, i] for every column i at once\n",
    "    quad = (J_u * T).sum(dim=0)  # (p,)\n",
    "\n",
    "    # diag(Σ Jᵀ S⁻¹ J Σ)_i = sigma_i² * quad_i\n",
    "    return (sigma - sigma**2 * quad).clamp(min=0.0)\n",
    "\n",
    "\n",
    "def batch_diag_JSigmaJT(u_proj: torch.Tensor,\n",
    "                        sigma_batch: torch.Tensor,\n",
    "                        m_batch: torch.Tensor,\n",
    "                        h_func,\n",
    "                        eps: float = 1e-6) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Per-sample diagonal of the constraint-projected covariance.\n",
    "\n",
    "    For each sample b, linearizes the constraints at u_proj[b] and returns\n",
    "    diag_JSigmaJT(sigma_batch[b], J_b, eps), i.e. the length-p diagonal of\n",
    "        Σ_b - Σ_b J_bᵀ (J_b Σ_b J_bᵀ + eps I)⁻¹ J_b Σ_b,\n",
    "    with Σ_b = diag(sigma_batch[b]) and J_b = ∂h/∂u at u_proj[b].\n",
    "    (This is not diag(J Σ Jᵀ), which would have length k.)\n",
    "\n",
    "    Args:\n",
    "      u_proj:      (B, p)  projected outputs u*\n",
    "      sigma_batch: (B, p)  variances (diagonal of Σ)\n",
    "      m_batch:     (B, …)  extra per-sample params passed to h_func\n",
    "      h_func:      callable h(u, m) → Tensor (k,) or scalar\n",
    "      eps:         ridge added to J Σ Jᵀ\n",
    "    Returns:\n",
    "      (B, p) tensor of per-coordinate variances after projection\n",
    "    \"\"\"\n",
    "    def one(sigma, u, m):\n",
    "        # constraint Jacobian at u*: (k, p), or (p,) when h is scalar\n",
    "        J_u = jacrev(lambda u_: h_func(u_, m))(u)\n",
    "        if J_u.ndim == 1:\n",
    "            J_u = J_u.unsqueeze(0)\n",
    "        return diag_JSigmaJT(sigma, J_u, eps)\n",
    "\n",
    "    return vmap(one)(sigma_batch, u_proj, m_batch)\n",
    "\n",
    "\n",
    "# -----------------------------------------------\n",
    "# Example usage, after your forward pass:\n",
    "\n",
    "# u_proj = ProjectWeightedImplicit.apply(xi_batch, sigma_batch, m_batch, h_func, max_iter)\n",
    "\n",
    "# Now get the projected variances:\n",
    "# variances = batch_diag_JSigmaJT(u_proj, sigma_batch, m_batch, h_func)\n",
    "\n",
    "# `variances[b,i]` is the projected variance of coordinate i for sample b,\n",
    "# i.e. the diagonal of Σ - Σ Jᵀ (J Σ Jᵀ)⁻¹ J Σ -- not (J Σ Jᵀ)_{ii}.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4af1764-71b4-4c60-8959-f4f08cf5fe42",
   "metadata": {},
   "outputs": [],
   "source": [
    "variances = batch_diag_JSigmaJT(new_u_proj, var_flat, m_flat, full_residual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0845b57b-0aeb-4ed6-a197-ee3afa9bd03d",
   "metadata": {},
   "outputs": [],
   "source": [
    "variances = variances.view(nf, nx, nt, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f791304-76ea-4e1f-8a8c-ce2d2937821b",
   "metadata": {},
   "outputs": [],
   "source": [
    "stds = torch.sqrt(variances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5825100a-2722-4f1b-b053-eed6dd9098e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "stds = stds.detach().cpu()\n",
    "u_proj_reshaped = u_proj_reshaped.detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f5f22d-34f1-4b5a-819d-c553113c7722",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t_idx = len(t[slice(*tpred)])//2\n",
    "t_idx = 2\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):  # currently only sample 0\n",
    "    with torch.no_grad():\n",
    "        # u_proj_reshaped / stds are the projected (ProbHardE2E) mean and std\n",
    "        pred_mu = u_proj_reshaped[parameters_idx, :, t_idx, 0]\n",
    "        pred_std = stds[parameters_idx, :, t_idx, 0]\n",
    "        true_vals = y_ood_test[parameters_idx, :, t_idx, 0]\n",
    "\n",
    "        plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "        plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k=x_ood_test[parameters_idx, 0, 0, 0], dataset=dataset))\n",
    "        plt.xlabel(\"x\")\n",
    "        # raw string: the LaTeX backslashes are otherwise invalid escape sequences\n",
    "        plt.plot(grid, pred_mu, '--', lw=2, label=r\"predicted $\\mu$ and $\\pm 3\\sigma$ (ProbHardE2E)\")\n",
    "        plt.fill_between(grid, pred_mu + 3 * pred_std, pred_mu - 3 * pred_std, alpha=0.2)\n",
    "        plt.plot(grid, true_vals, color=\"green\", label=\"true\")\n",
    "        print(torch.norm(true_vals - pred_mu))\n",
    "        plt.legend()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f02c9bc-41d9-42b5-b108-ab8af7c7bf6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "t_idx = 10\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):  # currently only sample 0\n",
    "    with torch.no_grad():\n",
    "        # Compact, high-DPI figure\n",
    "        fig, ax = plt.subplots(figsize=(4.5, 2.8), dpi=300)\n",
    "\n",
    "        time_label = t[slice(*tpred)][t_idx].item()\n",
    "        param_val = x_ood_test[parameters_idx, 0, 0, 0].item()\n",
    "        true_vals = y_ood_test[parameters_idx, :, t_idx, 0]\n",
    "        # new_mu / new_std are not moved to the CPU explicitly in the visible cells\n",
    "        # (unlike u_proj_reshaped and stds); matplotlib cannot plot CUDA tensors.\n",
    "        pred_mu_conserv = new_mu[parameters_idx, :, t_idx, 0].detach().cpu()\n",
    "        pred_std_conserv = new_std[parameters_idx, :, t_idx, 0].detach().cpu()\n",
    "        pred_mu_e2e = u_proj_reshaped[parameters_idx, :, t_idx, 0]\n",
    "        pred_std_e2e = stds[parameters_idx, :, t_idx, 0]\n",
    "\n",
    "        # True solution\n",
    "        ax.plot(grid, true_vals, color=\"black\", lw=1.5, label=\"True\", zorder=3)\n",
    "\n",
    "        # ProbConserv: mean and ±3σ band\n",
    "        ax.plot(grid, pred_mu_conserv, '--', lw=1.2, color=\"#ef233c\",\n",
    "                label=r\"ProbConserv\", zorder=2)\n",
    "        ax.fill_between(grid,\n",
    "                        pred_mu_conserv + 3 * pred_std_conserv,\n",
    "                        pred_mu_conserv - 3 * pred_std_conserv,\n",
    "                        color=\"#ef233c\", alpha=0.3, label=\"_nolegend_\", zorder=1)\n",
    "\n",
    "        # ProbHardE2E: mean and ±3σ band\n",
    "        ax.plot(grid, pred_mu_e2e, '--', lw=1.5, color=\"#3a86ff\",\n",
    "                label=r\"ProbHardE2E\", zorder=4)\n",
    "        ax.fill_between(grid,\n",
    "                        pred_mu_e2e + 3 * pred_std_e2e,\n",
    "                        pred_mu_e2e - 3 * pred_std_e2e,\n",
    "                        color=\"#3a86ff\", alpha=0.3, label=\"_nolegend_\", zorder=3)\n",
    "\n",
    "        # Labels, ticks, grid, legend\n",
    "        ax.set_xlabel(\"x\", fontsize=8)\n",
    "        ax.set_ylabel(r\"$u(x, t={:.2f})$\".format(time_label), fontsize=8)\n",
    "        ax.tick_params(labelsize=7)\n",
    "        ax.grid(True, linestyle=\"--\", alpha=0.4)\n",
    "        ax.legend(fontsize=7, loc=\"upper right\", frameon=False)\n",
    "\n",
    "        # Tight layout for embedding in the paper\n",
    "        fig.tight_layout(pad=0.3)\n",
    "        plt.show()\n",
    "\n",
    "        print(\"L2 error:\", torch.norm(true_vals - pred_mu_e2e).item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2918b7b-cab2-4da1-8179-9d49cdd14170",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t_idx = len(t[slice(*tpred)])//2\n",
    "t_idx = 2\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):  # currently only sample 0\n",
    "    with torch.no_grad():\n",
    "        # Predictions are for the OOD test parameters (x_ood_test), so compare\n",
    "        # against y_ood_test, as the other plotting and MSE cells do.\n",
    "        true_vals = y_ood_test[parameters_idx, :, t_idx, 0]\n",
    "        pc_mu = new_mu[parameters_idx, :, t_idx, 0].detach().cpu()\n",
    "        pc_std = new_std[parameters_idx, :, t_idx, 0].detach().cpu()\n",
    "        e2e_mu = u_proj_reshaped[parameters_idx, :, t_idx, 0]\n",
    "        e2e_std = stds[parameters_idx, :, t_idx, 0]\n",
    "\n",
    "        plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "        plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k=x_ood_test[parameters_idx, 0, 0, 0], dataset=dataset))\n",
    "        plt.xlabel(\"x\")\n",
    "\n",
    "        # raw strings: the LaTeX backslashes are otherwise invalid escape sequences\n",
    "        plt.plot(grid, pc_mu, '--', lw=2, label=r\"predicted $\\mu$ and $\\pm 3\\sigma$ (probconserv)\")\n",
    "        plt.fill_between(grid, pc_mu + 3 * pc_std, pc_mu - 3 * pc_std, alpha=0.2)\n",
    "\n",
    "        plt.plot(grid, e2e_mu, '--', lw=2, label=r\"predicted $\\mu$ and $\\pm 3\\sigma$ (probharde2e)\")\n",
    "        plt.fill_between(grid, e2e_mu + 3 * e2e_std, e2e_mu - 3 * e2e_std, alpha=0.2)\n",
    "\n",
    "        # true solution, drawn once so the legend has a single entry\n",
    "        plt.plot(grid, true_vals, color=\"green\", label=\"true\")\n",
    "\n",
    "        print(torch.norm(true_vals - pc_mu))\n",
    "        plt.legend()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2614e9d3-c4bb-4f3a-8ed6-bb81c2881208",
   "metadata": {},
   "outputs": [],
   "source": [
    "var_flat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35cc21f-5ab7-47a0-9460-ac32a0975993",
   "metadata": {},
   "outputs": [],
   "source": [
    "variances.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6659e913-9ab1-45fd-8c9b-e22aac0f3134",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_proj_reshaped.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f53de336-182c-4ed7-9bde-3a493a1f3ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_idx = 10\n",
    "plt.plot(y_train[0,:,t_idx,0],lw=4,label=\"true solution\")\n",
    "plt.plot(mu[0,:,t_idx,0].cpu().detach(),'--',lw=2,label=\"unconstrained\")\n",
    "plt.plot(new_mu[0,:,t_idx,0].cpu().detach(),label=\"probconserv\")\n",
    "plt.plot(u_proj_reshaped[0,:,t_idx,0].cpu().detach(),'-.',lw=2,label=\"probharde2e\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b5f5a1-f5a7-4b0c-9b69-f963d3756103",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t_idx = 2\n",
    "# plt.plot(y_train[0,:,t_idx,0],lw=4,label=\"true solution\")\n",
    "# plt.plot(mu[0,:,t_idx,0].cpu().detach(),'--',lw=2,label=\"unconstrained\")\n",
    "# plt.plot(new_mu[0,:,t_idx,0].cpu().detach(),label=\"probconserv\")\n",
    "# plt.plot(new_u_proj_reshaped[0,:,t_idx,0].cpu().detach(),'-.',lw=2,label=\"probharde2e\")\n",
    "# plt.legend()\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e995547a-8fdd-4a15-9096-fdd970b9b220",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(new_u_proj - u_proj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16a81126-a48a-4424-8290-6a9cd29b85f6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f2258f-3c7f-4302-b420-fca2acb5541a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Constraint function ---\n",
    "def h(u):\n",
    "    return u[0]**2 + u[1]**2 - 1.0\n",
    "\n",
    "# Check backprop\n",
    "xi_batch = torch.tensor([[2.0, 1.0],\n",
    "                         [0.1, 0.1],\n",
    "                         [-1.5, -0.5]], dtype=torch.float32, device=device)\n",
    "xi_leaf = xi_batch.clone().detach().requires_grad_(True)\n",
    "\n",
    "u_projected = fast_project_batched(mu_flat, h)\n",
    "u_proj = fast_project_weighted(mu_flat, var_flat, h_func=h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321130f8-924f-4198-8715-49e8f4797d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(utils.compute_sampling_crps_by_example(mu, var, y,nbins=500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6df72eff-6964-4de7-9326-6b9458282089",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(utils.compute_sampling_crps_by_example(new_mu, new_std.square(), y,nbins=500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "260d83c1-f32c-4693-8694-35b22ab1e6a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(utils.compute_sampling_crps_by_example(u_proj_reshaped, variances.cpu(), y,nbins=500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a2ba91c-01a6-4903-8b3d-89cd9dc39757",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633aa3b5-ea6a-4acf-9b78-63adc39d6418",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(mu-u_proj_reshaped)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45271205-293d-4a07-8c1f-3fa249cfe744",
   "metadata": {},
   "source": [
    "## Experiments: comparing CRPS computed with and without sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3f1485f-07fb-4039-a0c4-d0480edbd8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.loss_func(out, y.to(device))/len(out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f826c2e-eee5-4b4a-abc0-daf2e505aa50",
   "metadata": {},
   "outputs": [],
   "source": [
    "out[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd27a2af-c705-4621-b36b-614bf61fc40c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd459b4-d1b8-4d83-8540-074bb8374287",
   "metadata": {},
   "outputs": [],
   "source": [
    "crps_by_sample = utils.compute_sampling_crps_by_example(mu, var, y,nbins=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "400f807a-3676-42a2-82fd-f0ee6999c64e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.mean(crps_by_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913e8446-4753-427f-8d64-d69dbed2b866",
   "metadata": {},
   "outputs": [],
   "source": [
    "std = torch.sqrt(var)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8ad8c6-28e9-45ff-a9e3-a7c828ed5f2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(x_train.to(device))\n",
    "x = train_loader.dataset.tensors[0]\n",
    "y = train_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "\n",
    "new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "    mu=mu[:, :, :, 0], \n",
    "    std=std[:, :, :, 0], \n",
    "    mass_rhs_func=mass_rhs_func, \n",
    "    t=t, \n",
    "    tpred=tpred, \n",
    "    grid_train=grid, \n",
    "    precis_g=np.inf,\n",
    "    second_deriv_alpha=None,\n",
    ")\n",
    "new_mu = new_mu[:, :, :, None]\n",
    "new_std = new_std[:, :, :, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64b75faa-0ce0-47f9-ae71-2aae64386daf",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_idx = 1\n",
    "parameter_idx = 0\n",
    "with torch.no_grad():\n",
    "    plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "    plt.title(\"Learning Heat Equation for parameter = {k:.2f}\".format(k = x_train[parameter_idx,0,0,0]))\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.plot(grid, mu[parameter_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "    plt.fill_between(grid, mu[parameter_idx,:,t_idx,0]+3*std[parameter_idx,:,t_idx,0], mu[parameter_idx,:,t_idx,0]-3*std[parameter_idx,:,t_idx,0], alpha=0.2)\n",
    "    plt.plot(grid, y_train[parameter_idx,:,t_idx,0], color = \"green\", label = \"true\")\n",
    "    plt.legend(loc=\"upper right\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbeef4c3-208f-490b-9f76-24ba47485adc",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_idx = 1\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):\n",
    "    with torch.no_grad():\n",
    "        plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "        plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k = x_train[parameters_idx,0,0,0], dataset = dataset))\n",
    "        plt.xlabel(\"x\")\n",
    "        plt.plot(grid, new_mu[parameters_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "        plt.fill_between(grid, new_mu[parameters_idx,:,t_idx,0]+3*new_std[parameters_idx,:,t_idx,0], new_mu[parameters_idx,:,t_idx,0]-3*new_std[parameters_idx,:,t_idx,0], alpha=0.2)\n",
    "        plt.plot(grid, y_train[parameters_idx,:,t_idx,0], color = \"green\", label = \"true\")        \n",
    "        plt.legend()\n",
    "        # plt.ylim(-1.0,1.5)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ea3ace5-4141-4975-9823-0ec6d7aab7cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "ucons_stats_train = utils.compute_all_metrics_avg((mu, torch.square(std)), y_train, {})\n",
    "ucons_stats_train[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, torch.square(std), y_train).item()\n",
    "ucons_stats_train[\"rmsce_all\"] = utils.compute_rmsce(mu, torch.square(std), y_train).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc1aeab-f5a3-45e5-8f17-c47458e866f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ucons_stats_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6090d363-b486-478d-9694-278e7526154d",
   "metadata": {},
   "outputs": [],
   "source": [
    "probconserv_stats_train = utils.compute_all_metrics_avg((new_mu, torch.square(new_std)), y_train, {})\n",
    "probconserv_stats_train[\"nMeRCI_all\"] = utils.compute_nMeRCI(new_mu, torch.square(new_std), y_train).item()\n",
    "probconserv_stats_train[\"rmsce_all\"] = utils.compute_rmsce(new_mu, torch.square(new_std), y_train).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbc2b1ac-a8a6-4991-bc08-aad3f87982e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "cerr = (probconserv.get_empirical_mass_rhs(mu[:, :,  :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "new_cerr = (probconserv.get_empirical_mass_rhs(new_mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "ucons_stats_train[\"cerr_by_example\"] = cerr.tolist()\n",
    "ucons_stats_train[\"mcerr\"] = cerr.mean().item()\n",
    "probconserv_stats_train[\"cerr_by_example\"] = new_cerr.tolist()\n",
    "probconserv_stats_train[\"mcerr\"] = new_cerr.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa190002-f493-40a0-857d-79de73857da7",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(x_ood_test.to(device))\n",
    "\n",
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "if model.probconserv:\n",
    "    _mu, _var, = out[0].cpu(), out[1].cpu()\n",
    "    _std = torch.sqrt(_var)\n",
    "    mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "    new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "                                                    mu=_mu[:, :, :, 0], \n",
    "                                                    std=_std[:, :, :, 0], \n",
    "                                                    mass_rhs_func=mass_rhs_func, \n",
    "                                                    t=t, \n",
    "                                                    tpred=tpred, \n",
    "                                                    grid_train=grid, \n",
    "                                                    precis_g=np.inf,\n",
    "                                                    second_deriv_alpha=None,\n",
    "                                                    )\n",
    "    out = (new_mu.unsqueeze(-1), torch.square(new_std).unsqueeze(-1))\n",
    "\n",
    "mu, var, = out[0].cpu(), out[1].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39e9bde3-2eba-4801-aa30-88c2c043ec20",
   "metadata": {},
   "outputs": [],
   "source": [
    "std = torch.sqrt(var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2384c2df-e413-4a42-8107-977afa42306e",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_idx = 1\n",
    "parameter_idx = 0\n",
    "with torch.no_grad():\n",
    "    plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "    plt.title(\"Learning Heat Equation for parameter = {k:.2f}\".format(k = x_ood_test[parameter_idx,0,0,0]))\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.plot(grid, mu[parameter_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "    plt.fill_between(grid, mu[parameter_idx,:,t_idx,0]+3*std[parameter_idx,:,t_idx,0], mu[parameter_idx,:,t_idx,0]-3*std[parameter_idx,:,t_idx,0], alpha=0.2)\n",
    "    plt.plot(grid, y_ood_test[parameter_idx,:,t_idx,:], color = \"green\", label = \"true\")\n",
    "    plt.legend(loc=\"upper right\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbca5941-d33a-4608-b00c-99f5024fd3aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "    mu=mu[:, :, :, 0], \n",
    "    std=std[:, :, :, 0], \n",
    "    mass_rhs_func=mass_rhs_func, \n",
    "    t=t, \n",
    "    tpred=tpred, \n",
    "    grid_train=grid, \n",
    "    precis_g=np.inf,\n",
    "    second_deriv_alpha=None,\n",
    ")\n",
    "new_mu = new_mu[:, :, :, None]\n",
    "new_std = new_std[:, :, :, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b809290-3cb4-4d19-a03a-ed6435a705a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t_idx = len(t[slice(*tpred)])//2\n",
    "t_idx = 1\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):\n",
    "    with torch.no_grad():\n",
    "        plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "        plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k = x_ood_test[parameters_idx,0,0,0], dataset = dataset))\n",
    "        plt.xlabel(\"x\")\n",
    "        plt.plot(grid, new_mu[parameters_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "        plt.fill_between(grid, new_mu[parameters_idx,:,t_idx,0]+3*new_std[parameters_idx,:,t_idx,0], new_mu[parameters_idx,:,t_idx,0]-3*new_std[parameters_idx,:,t_idx,0], alpha=0.2)\n",
    "        plt.plot(grid, y_ood_test[parameters_idx,:,t_idx,0], color = \"green\", label = \"true\")\n",
    "        print(torch.norm(y_ood_test[parameters_idx,:,t_idx,0] - new_mu[parameters_idx,:,t_idx,0]))\n",
    "        plt.legend()\n",
    "        # plt.ylim(-1.0,1.5)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5762b3e-644a-4cae-ac0f-5ab97ce326bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "ucons_stats_test = utils.compute_all_metrics_avg((mu, torch.square(std)), y_ood_test, {})\n",
    "ucons_stats_test[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, torch.square(std), y_ood_test).item()\n",
    "ucons_stats_test[\"rmsce_all\"] = utils.compute_rmsce(mu, torch.square(std), y_ood_test).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053bdcfd-e155-4a61-9532-50a72beda9c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "probconserv_stats_test = utils.compute_all_metrics_avg((new_mu, torch.square(new_std)), y_ood_test, {})\n",
    "probconserv_stats_test[\"nMeRCI_all\"] = utils.compute_nMeRCI(new_mu, torch.square(new_std), y_ood_test).item()\n",
    "probconserv_stats_test[\"rmsce_all\"] = utils.compute_rmsce(new_mu, torch.square(new_std), y_ood_test).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "973a1ea0-a774-4d8a-9947-fabe05d76bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "cerr = (probconserv.get_empirical_mass_rhs(mu[:, :,  :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "new_cerr = (probconserv.get_empirical_mass_rhs(new_mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "ucons_stats_test[\"cerr_by_example\"] = cerr.tolist()\n",
    "ucons_stats_test[\"mcerr\"] = cerr.mean().item()\n",
    "probconserv_stats_test[\"cerr_by_example\"] = new_cerr.tolist()\n",
    "probconserv_stats_test[\"mcerr\"] = new_cerr.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da438412-1398-430d-ba74-82db34174f39",
   "metadata": {},
   "outputs": [],
   "source": [
    "ucons_stats_train"
   ]
  },
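  {
   "cell_type": "markdown",
   "id": "09f0df61-ed39-4fe8-8463-87a363b11ef6",
   "metadata": {},
   "source": [
    "## End-to-End (E2E) Training\n",
    "\n",
    "The cells below rebuild the train/validation/test datasets from `args`, construct the model selected by `--model` (the `OutputVarFNO2d` option wraps `FNO2d` in `ProbHardE2E`), and train it."
   ]
  },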
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "c36a3814-61c7-4f14-a367-bad3828213de",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" \n",
    "experiment_name = \"trial\"\n",
    "# print(f\"Experiment: {experiment_name}\")\n",
    "# print(args)\n",
    "save_args = utils.filter_config(args, [\"generate\", \"--no_train\", \"--ood_dataset_params\", \"--tplot\"], mode=\"remove\")  # Also removes \".\" keys\n",
    "\n",
    "is_train = not bool(args[\"--no_train\"])\n",
    "\n",
    "# Parameters\n",
    "n_x = int(args[\"--grid_len\"])\n",
    "n_t = int(args[\"--time_len\"])\n",
    "n_samples = int(args[\"--n_samples\"])\n",
    "n_train = int(0.8 * n_samples)\n",
    "n_valid = int(0.2 * n_samples)\n",
    "n_test = n_samples // 2\n",
    "\n",
    "is_markov = False\n",
    "\n",
    "dataset = args[\"--dataset\"]\n",
    "dataset_params = [float(val) for val in args[\"--dataset_params\"].split(\",\")]\n",
    "train_ood_dataset_params = [float(val) for val in args[\"--train_ood_dataset_params\"].split(\",\")]\n",
    "ood_dataset_params = train_ood_dataset_params\n",
    "if not is_train:\n",
    "    ood_dataset_params = [float(val) for val in args[\"--ood_dataset_params\"].split(\",\")]\n",
    "\n",
    "tpred = [int(val) for val in args[\"--predict_time\"].split(\",\")]\n",
    "\n",
    "fno_modes = int(args[\"--fno_modes\"])\n",
    "fno_width = int(args[\"--fno_width\"])\n",
    "\n",
    "batch_size = int(args[\"--batch_size\"])\n",
    "lr = float(args[\"--lr\"])\n",
    "epochs = int(args[\"--epochs\"])\n",
    "step_size = 50\n",
    "gamma = 0.5\n",
    "# ################\n",
    "\n",
    "# Set seed\n",
    "utils.set_seed(int(args[\"--seed\"]))\n",
    "\n",
    "# Generate dataset\n",
    "if dataset.lower() == \"HeatEquation_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 2 * np.pi, n_x)\n",
    "    dataset_class = HeatEquation_1D\n",
    "elif dataset.lower() == \"PME_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 1, n_x)\n",
    "    dataset_class = PME_1D\n",
    "elif dataset.lower() == \"StefanPME_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 1, n_x)\n",
    "    dataset_class = StefanPME_1D\n",
    "elif dataset.lower() == \"LinearAdvection_1D\".lower():\n",
    "    t = torch.linspace(0, 1, n_t)\n",
    "    grid = torch.linspace(0, 1, n_x)\n",
    "    dataset_class = LinearAdvection_1D\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "\n",
    "t_sliced = t[slice(*tpred)]\n",
    "T = len(t_sliced)\n",
    "\n",
    "def get_xy_from_pu(p, u, is_markov=False):\n",
    "    T = u.shape[2]\n",
    "    #TODO: What does is_markov do here?\n",
    "    if is_markov:\n",
    "        x0, y0 = p, u\n",
    "        \n",
    "        y0_vectorized = rearrange(y0[:, :, 0:T-1], \"nf nx nt 1 -> (nf nt) nx 1\")\n",
    "        x0 = repeat(x0, \"nf nx 1 -> (nf nt) nx 1\", nt=T-1)\n",
    "        x = torch.cat([x0, y0_vectorized], dim=-1)\n",
    "        \n",
    "        y = rearrange(y0[:, :, 1:T], \"nf nx nt 1 -> (nf nt) nx 1\")\n",
    "    else:\n",
    "        x, y = p, u\n",
    "        x = repeat(x, \"nf nx 1 -> nf nx T 1\", T=T)\n",
    "    return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "8d06a5a3-4116-421c-bd9a-e01ead11928c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here 160\n",
      "torch.Size([160, 100, 1]) torch.Size([160, 100, 20, 1])\n",
      "torch.Size([160, 100, 20, 1]) torch.Size([160, 100, 20, 1])\n"
     ]
    }
   ],
   "source": [
    "if is_train:\n",
    "    # Train data\n",
    "    print(\"Here\", n_train)\n",
    "    a, u, p = dataset_class.generate_dataset(n_train, grid, t, tpred, *dataset_params)\n",
    "    print(a.shape, u.shape)\n",
    "    x_train, y_train = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # Validation data\n",
    "    a, u, p = dataset_class.generate_dataset(n_valid, grid, t, tpred, *dataset_params)\n",
    "    x_valid, y_valid = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # In-distribution test data\n",
    "    a, u, p = dataset_class.generate_dataset(n_test, grid, t, tpred, *dataset_params)\n",
    "    x_id_test, y_id_test = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # Out-of-distribution inputs only\n",
    "    a, u, p = dataset_class.generate_dataset(n_test, grid, t, tpred, *train_ood_dataset_params)\n",
    "    x_ood_test, y_ood_test = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "\n",
    "    # Data loaders\n",
    "    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), \n",
    "                                            batch_size=batch_size, shuffle=True)\n",
    "    valid_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_valid, y_valid), \n",
    "                                            batch_size=batch_size, shuffle=False)\n",
    "    id_test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_id_test, y_id_test), \n",
    "                                            batch_size=batch_size, shuffle=False)\n",
    "    ood_test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_ood_test, y_ood_test), \n",
    "                                            batch_size=batch_size, shuffle=False)\n",
    "else:\n",
    "    # OOD test data\n",
    "    a, u, p = dataset_class.generate_dataset(n_test, grid, t, tpred, *ood_dataset_params)\n",
    "    x_ood_test, y_ood_test = get_xy_from_pu(p, u, is_markov=is_markov)\n",
    "    ood_test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_ood_test, y_ood_test), \n",
    "                                            batch_size=batch_size, shuffle=False)\n",
    "\n",
    "print(x_train.shape, y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "95fa4119-7947-412c-9c9b-e79e9d97f6b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_context = {\n",
    "    \"t\": t.to(device),\n",
    "    \"tpred\": torch.tensor(tpred).to(device),\n",
    "    \"grid_train\": grid.to(device),\n",
    "    \"dataset_class\": dataset_class\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e13b2903-c201-478d-b374-b1b80f813d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "uq = False\n",
    "model_name = args[\"--model\"]\n",
    "n_models = 1\n",
    "fno_modes2 = min(fno_modes, 12)\n",
    "if args[\"--model\"].lower() == \"FNO2d\".lower():\n",
    "    FNO2d_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width, \"output_var\": True}\n",
    "    model = FNO2d(**FNO2d_params).to(device)\n",
    "elif args[\"--model\"].lower().startswith(\"EnsembleFNO2d\".lower()):\n",
    "    FNO2d_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width}\n",
    "    n_models = int(args[\"--m.n_models\"])\n",
    "    utils.filter_config(args, [\"--m.n_models\"], mode=\"add\", new_config=save_args)\n",
    "    model = EnsembleNO(base_model_class=FNO2d, base_model_params=FNO2d_params, n_models=n_models)\n",
    "    uq = True\n",
    "elif args[\"--model\"].lower().startswith(\"BayesianFNO2d\".lower()):\n",
    "    FNO2d_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width}\n",
    "    model = BayesianNO(base_model_class=FNO2d, base_model_params=FNO2d_params)\n",
    "    uq = True\n",
    "elif args[\"--model\"].lower().startswith(\"MCDropoutFNO2d\".lower()):\n",
    "    FNO2d_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width}\n",
    "    dropout = float(args[\"--m.drop_prob\"])\n",
    "    n_dropouts = int(args[\"--m.n_models\"])\n",
    "    utils.filter_config(args, [\"--m.n_models\", \"--m.drop_prob\"], mode=\"add\", new_config=save_args)\n",
    "    model = MCDropoutNO(base_model_class=FNO2d, base_model_params=FNO2d_params, dropout=dropout, n_dropouts=n_dropouts)\n",
    "    uq = True\n",
    "elif args[\"--model\"].lower().startswith(\"OutputVarFNO2d\".lower()):\n",
    "    FNO2d_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width}\n",
    "    #model = OutputVarNO(base_model_class=FNO2d, probconserv=False, base_model_params=FNO2d_params)\n",
    "    model = ProbHardE2E(base_model_class=FNO2d, probconserv=False, base_model_params=FNO2d_params, constraint_context=constraint_context, noneq_constraint_e2e=True)\n",
    "    uq = True\n",
    "elif args[\"--model\"].lower().startswith(\"DiverseFNO2d\".lower()):\n",
    "    FNO2d_params = {\"modes1\": fno_modes, \"modes2\": fno_modes2, \"width\": fno_width}\n",
    "    lam = float(args[\"--m.reg_strength\"])\n",
    "    reg_type = args[\"--m.reg_type\"]\n",
    "    n_models = int(args[\"--m.n_models\"])\n",
    "    n_regularize = int(args[\"--m.n_regularize\"])\n",
    "    utils.filter_config(args, [\"--m.n_models\", \"--m.reg_strength\", \"--m.reg_type\", \"--m.n_regularize\"], mode=\"add\", new_config=save_args)\n",
    "    model = DiverseFNO2d(reg_loss=reg_type, n_outputs=n_models, bias_last=False, lam=lam, n_regularize=n_regularize, **FNO2d_params).to(device)\n",
    "    uq = True\n",
    "else:\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "dd02c190-49e0-490f-85a4-1189103cef61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0101, 0.0202, 0.0303, 0.0404, 0.0505, 0.0606, 0.0707, 0.0808,\n",
       "        0.0909, 0.1010, 0.1111, 0.1212, 0.1313, 0.1414, 0.1515, 0.1616, 0.1717,\n",
       "        0.1818, 0.1919, 0.2020, 0.2121, 0.2222, 0.2323, 0.2424, 0.2525, 0.2626,\n",
       "        0.2727, 0.2828, 0.2929, 0.3030, 0.3131, 0.3232, 0.3333, 0.3434, 0.3535,\n",
       "        0.3636, 0.3737, 0.3838, 0.3939, 0.4040, 0.4141, 0.4242, 0.4343, 0.4444,\n",
       "        0.4545, 0.4646, 0.4747, 0.4848, 0.4949, 0.5051, 0.5152, 0.5253, 0.5354,\n",
       "        0.5455, 0.5556, 0.5657, 0.5758, 0.5859, 0.5960, 0.6061, 0.6162, 0.6263,\n",
       "        0.6364, 0.6465, 0.6566, 0.6667, 0.6768, 0.6869, 0.6970, 0.7071, 0.7172,\n",
       "        0.7273, 0.7374, 0.7475, 0.7576, 0.7677, 0.7778, 0.7879, 0.7980, 0.8081,\n",
       "        0.8182, 0.8283, 0.8384, 0.8485, 0.8586, 0.8687, 0.8788, 0.8889, 0.8990,\n",
       "        0.9091, 0.9192, 0.9293, 0.9394, 0.9495, 0.9596, 0.9697, 0.9798, 0.9899,\n",
       "        1.0000])"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "11c859b7-92f8-4125-b562-fad7c2a4683c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "4945a696-f59c-4d0b-a800-08616a447547",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Updating Sigma now\n",
      "Epoch 0: Train loss=-2038.221997, Validation loss=-2317.116699 (saved)\n",
      "Epoch 1: Train loss=-2625.459814, Validation loss=-3101.577148 (saved)\n",
      "Epoch 2: Train loss=-3668.263696, Validation loss=-4773.478516 (saved)\n",
      "Epoch 3: Train loss=-6973.070459, Validation loss=-220.462415 \n",
      "Epoch 4: Train loss=-6949.894989, Validation loss=-7897.492969 (saved)\n",
      "Epoch 5: Train loss=-8517.266602, Validation loss=-8451.830859 (saved)\n",
      "Epoch 6: Train loss=-9038.646973, Validation loss=-9794.991406 (saved)\n",
      "Epoch 7: Train loss=-10817.341016, Validation loss=-12355.973438 (saved)\n",
      "Epoch 8: Train loss=-13495.757617, Validation loss=3719.760352 \n",
      "Epoch 9: Train loss=-9181.301715, Validation loss=-12104.085938 \n",
      "Epoch 10: Train loss=-11348.513086, Validation loss=-11309.700000 \n",
      "Epoch 11: Train loss=-11606.242676, Validation loss=-12089.951172 \n",
      "Epoch 12: Train loss=-12896.941406, Validation loss=-14082.765625 (saved)\n",
      "Epoch 13: Train loss=-15275.534375, Validation loss=-16888.144531 (saved)\n",
      "Epoch 14: Train loss=-17862.850391, Validation loss=-18009.008594 (saved)\n",
      "Epoch 15: Train loss=-16293.049414, Validation loss=-17987.829688 \n",
      "Epoch 16: Train loss=-17477.007031, Validation loss=-17572.114844 \n",
      "Epoch 17: Train loss=-18008.173047, Validation loss=-18666.041406 (saved)\n",
      "Epoch 18: Train loss=-19237.147461, Validation loss=-19971.125781 (saved)\n",
      "Epoch 19: Train loss=-17793.444629, Validation loss=-15194.638281 \n",
      "Epoch 20: Train loss=-17609.040039, Validation loss=-17788.156250 \n",
      "Epoch 21: Train loss=-17664.063867, Validation loss=-17823.154687 \n",
      "Epoch 22: Train loss=-18106.350195, Validation loss=-18515.067188 \n",
      "Epoch 23: Train loss=-18844.712500, Validation loss=-19260.553125 \n",
      "Epoch 24: Train loss=-19587.252344, Validation loss=-19984.578906 (saved)\n",
      "Epoch 25: Train loss=-20245.589844, Validation loss=-20609.657031 (saved)\n",
      "Epoch 26: Train loss=-16260.097510, Validation loss=-17037.403125 \n",
      "Epoch 27: Train loss=-17148.801953, Validation loss=-16603.833594 \n",
      "Epoch 28: Train loss=-16459.715820, Validation loss=-16704.667187 \n",
      "Epoch 29: Train loss=-17034.477344, Validation loss=-17354.138281 \n",
      "Epoch 30: Train loss=-17660.968359, Validation loss=-18019.567969 \n",
      "Epoch 31: Train loss=-18258.843555, Validation loss=-18591.296094 \n",
      "Epoch 32: Train loss=-18853.052539, Validation loss=-19194.632031 \n",
      "Epoch 33: Train loss=-19460.989063, Validation loss=-19802.495312 \n",
      "Epoch 34: Train loss=-20057.888672, Validation loss=-20439.690625 \n",
      "Epoch 35: Train loss=-20628.125000, Validation loss=-20837.435156 (saved)\n",
      "Epoch 36: Train loss=-10817.612500, Validation loss=-16814.126563 \n",
      "Epoch 37: Train loss=-16046.473242, Validation loss=-15800.821875 \n",
      "Epoch 38: Train loss=-15216.044727, Validation loss=-14974.700000 \n",
      "Epoch 39: Train loss=-15276.320703, Validation loss=-15685.096875 \n",
      "Epoch 40: Train loss=-15940.900977, Validation loss=-16210.833594 \n",
      "Epoch 41: Train loss=-16365.045117, Validation loss=-16543.503125 \n",
      "Epoch 42: Train loss=-16668.891797, Validation loss=-16809.432031 \n",
      "Epoch 43: Train loss=-16918.175977, Validation loss=-17048.350000 \n",
      "Epoch 44: Train loss=-17151.599414, Validation loss=-17277.987500 \n",
      "Epoch 45: Train loss=-17359.006836, Validation loss=-17426.773438 \n",
      "Epoch 46: Train loss=-17480.681836, Validation loss=-17675.848438 \n",
      "Epoch 47: Train loss=-17729.441797, Validation loss=-17864.976562 \n",
      "Epoch 48: Train loss=-17983.000977, Validation loss=-18141.067969 \n",
      "Epoch 49: Train loss=-18279.315039, Validation loss=-18448.090625 \n",
      "Epoch 50: Train loss=-18515.317578, Validation loss=-18603.161719 \n",
      "Epoch 51: Train loss=-18673.992773, Validation loss=-18768.817188 \n",
      "Epoch 52: Train loss=-18847.361914, Validation loss=-18948.937500 \n",
      "Epoch 53: Train loss=-19031.376367, Validation loss=-19121.822656 \n",
      "Epoch 54: Train loss=-19213.182812, Validation loss=-19308.306250 \n",
      "Epoch 55: Train loss=-19343.429688, Validation loss=-14555.140625 \n",
      "Epoch 56: Train loss=-14494.941846, Validation loss=-16963.767188 \n",
      "Epoch 57: Train loss=-17191.218359, Validation loss=-17413.444531 \n",
      "Epoch 58: Train loss=-17533.781055, Validation loss=-17699.655469 \n",
      "Epoch 59: Train loss=-17780.635352, Validation loss=-17864.307812 \n",
      "Epoch 60: Train loss=-17936.851758, Validation loss=-18037.438281 \n",
      "Epoch 61: Train loss=-18116.052148, Validation loss=-18212.189063 \n",
      "Epoch 62: Train loss=-18282.067188, Validation loss=-18363.592969 \n",
      "Epoch 63: Train loss=-18425.985352, Validation loss=-18504.923438 \n",
      "Epoch 64: Train loss=-18567.347266, Validation loss=-18645.911719 \n",
      "Epoch 65: Train loss=-18709.398828, Validation loss=-18790.109375 \n",
      "Epoch 66: Train loss=-18854.905859, Validation loss=-18936.680469 \n",
      "Epoch 67: Train loss=-19003.017383, Validation loss=-19086.278125 \n",
      "Epoch 68: Train loss=-19154.325977, Validation loss=-19240.044531 \n",
      "Epoch 69: Train loss=-19310.198047, Validation loss=-19398.730469 \n",
      "Epoch 70: Train loss=-19470.470703, Validation loss=-19563.339062 \n",
      "Epoch 71: Train loss=-19636.060742, Validation loss=-19731.722656 \n",
      "Epoch 72: Train loss=-19806.570508, Validation loss=-19909.692969 \n",
      "Epoch 73: Train loss=-19983.140039, Validation loss=-20092.364844 \n",
      "Epoch 74: Train loss=-20160.022852, Validation loss=-20190.321094 \n",
      "Epoch 75: Train loss=-20092.487305, Validation loss=-20341.860938 \n",
      "Epoch 76: Train loss=-20296.786719, Validation loss=-20533.128125 \n",
      "Epoch 77: Train loss=-20559.415039, Validation loss=-20685.875000 \n",
      "Epoch 78: Train loss=-20743.531250, Validation loss=-20859.500000 (saved)\n",
      "Epoch 79: Train loss=-20877.848242, Validation loss=-20446.192188 \n",
      "Epoch 80: Train loss=-19825.828125, Validation loss=-20160.100000 \n",
      "Epoch 81: Train loss=-20548.514258, Validation loss=-20767.514062 \n",
      "Epoch 82: Train loss=-20770.478320, Validation loss=-20850.942188 \n",
      "Epoch 83: Train loss=-20997.899414, Validation loss=-21143.127344 (saved)\n",
      "Epoch 84: Train loss=-21223.758789, Validation loss=-21368.012500 (saved)\n",
      "Epoch 85: Train loss=-21434.113281, Validation loss=-21614.772656 (saved)\n",
      "Epoch 86: Train loss=-21400.981250, Validation loss=-21378.579688 \n",
      "Epoch 87: Train loss=-20990.982031, Validation loss=-21279.857031 \n",
      "Epoch 88: Train loss=-21550.066406, Validation loss=-21774.143750 (saved)\n",
      "Epoch 89: Train loss=-21844.704297, Validation loss=-21767.285156 \n",
      "Epoch 90: Train loss=-22018.555859, Validation loss=-22085.495312 (saved)\n",
      "Epoch 91: Train loss=-20760.982422, Validation loss=-13129.066406 \n",
      "Epoch 92: Train loss=-18155.760352, Validation loss=-20772.032813 \n",
      "Epoch 93: Train loss=-20131.074023, Validation loss=-20827.638281 \n",
      "Epoch 94: Train loss=-20500.394727, Validation loss=-20429.048438 \n",
      "Epoch 95: Train loss=-20559.506641, Validation loss=-20664.700000 \n",
      "Epoch 96: Train loss=-20758.876367, Validation loss=-20876.889062 \n",
      "Epoch 97: Train loss=-20970.171289, Validation loss=-21109.839062 \n",
      "Epoch 98: Train loss=-21204.828125, Validation loss=-21353.171094 \n",
      "Epoch 99: Train loss=-21465.266602, Validation loss=-21623.719531 \n",
      "Epoch 100: Train loss=-21686.401562, Validation loss=-21761.378906 \n",
      "Epoch 101: Train loss=-21827.701172, Validation loss=-21904.831250 \n",
      "Epoch 102: Train loss=-21970.165820, Validation loss=-22050.210938 \n",
      "Epoch 103: Train loss=-22105.669336, Validation loss=-22165.300000 (saved)\n",
      "Epoch 104: Train loss=-22232.776953, Validation loss=-22289.776562 (saved)\n",
      "Epoch 105: Train loss=-22242.112109, Validation loss=-22438.512500 (saved)\n",
      "Epoch 106: Train loss=-22423.901562, Validation loss=-22516.016406 (saved)\n",
      "Epoch 107: Train loss=-22478.393945, Validation loss=-22577.814063 (saved)\n",
      "Epoch 108: Train loss=-22103.915430, Validation loss=-22608.974219 (saved)\n",
      "Epoch 109: Train loss=-22558.345117, Validation loss=-22766.114844 (saved)\n",
      "Epoch 110: Train loss=-22659.841992, Validation loss=-22756.338281 \n",
      "Epoch 111: Train loss=-22480.245312, Validation loss=-22880.014062 (saved)\n",
      "Epoch 112: Train loss=-22192.490820, Validation loss=-22568.325781 \n",
      "Epoch 113: Train loss=-22714.270313, Validation loss=-22718.013281 \n",
      "Epoch 114: Train loss=-22709.331641, Validation loss=-22957.315625 (saved)\n",
      "Epoch 115: Train loss=-22724.033203, Validation loss=-21754.939063 \n",
      "Epoch 116: Train loss=-22007.416797, Validation loss=-22594.289062 \n",
      "Epoch 117: Train loss=-22122.333203, Validation loss=-22858.283594 \n",
      "Epoch 118: Train loss=-22463.307812, Validation loss=-22046.771875 \n",
      "Epoch 119: Train loss=-22242.024414, Validation loss=-22086.812500 \n",
      "Epoch 120: Train loss=-22324.697656, Validation loss=-22868.952344 \n",
      "Epoch 121: Train loss=-22756.254492, Validation loss=-22586.601562 \n",
      "Epoch 122: Train loss=-22565.922266, Validation loss=-22911.842969 \n",
      "Epoch 123: Train loss=-22776.887500, Validation loss=-22968.825781 (saved)\n",
      "Epoch 124: Train loss=-22778.652930, Validation loss=-23085.153125 (saved)\n",
      "Epoch 125: Train loss=-22561.893945, Validation loss=-23073.251563 \n",
      "Epoch 126: Train loss=-23034.171094, Validation loss=-22735.271875 \n",
      "Epoch 127: Train loss=-22706.673047, Validation loss=-23000.339844 \n",
      "Epoch 128: Train loss=-22800.560547, Validation loss=-23072.846875 \n",
      "Epoch 129: Train loss=-22991.650391, Validation loss=-23183.180469 (saved)\n",
      "Epoch 130: Train loss=-21322.055273, Validation loss=-22235.059375 \n",
      "Epoch 131: Train loss=-21719.304102, Validation loss=-22698.032813 \n",
      "Epoch 132: Train loss=-21925.130273, Validation loss=-20937.709375 \n",
      "Epoch 133: Train loss=-22169.463281, Validation loss=-22088.243750 \n",
      "Epoch 134: Train loss=-22428.353125, Validation loss=-22566.927344 \n",
      "Epoch 135: Train loss=-22703.815039, Validation loss=-22790.407031 \n",
      "Epoch 136: Train loss=-22800.351172, Validation loss=-22907.710156 \n",
      "Epoch 137: Train loss=-22929.382227, Validation loss=-22972.225781 \n",
      "Epoch 138: Train loss=-23024.763867, Validation loss=-23148.957813 \n",
      "Epoch 139: Train loss=-23139.986719, Validation loss=-23018.417187 \n",
      "Epoch 140: Train loss=-22554.527539, Validation loss=-20210.596094 \n",
      "Epoch 141: Train loss=-20348.089062, Validation loss=-22857.380469 \n",
      "Epoch 142: Train loss=-22089.244531, Validation loss=-22649.517188 \n",
      "Epoch 143: Train loss=-22370.541211, Validation loss=-22579.990625 \n",
      "Epoch 144: Train loss=-22561.254102, Validation loss=-22775.475000 \n",
      "Epoch 145: Train loss=-22748.439844, Validation loss=-22880.842187 \n",
      "Epoch 146: Train loss=-22897.722656, Validation loss=-22990.675781 \n",
      "Epoch 147: Train loss=-23042.576562, Validation loss=-23107.446094 \n",
      "Epoch 148: Train loss=-23150.852734, Validation loss=-23210.485156 (saved)\n",
      "Epoch 149: Train loss=-23252.718555, Validation loss=-23289.433594 (saved)\n",
      "Epoch 150: Train loss=-23310.899023, Validation loss=-23369.162500 (saved)\n",
      "Epoch 151: Train loss=-23334.010742, Validation loss=-23403.138281 (saved)\n",
      "Epoch 152: Train loss=-23399.294727, Validation loss=-23485.617969 (saved)\n",
      "Epoch 153: Train loss=-23288.936328, Validation loss=-23029.892969 \n",
      "Epoch 154: Train loss=-23251.796289, Validation loss=-23435.053125 \n",
      "Epoch 155: Train loss=-23029.846484, Validation loss=-22519.054688 \n",
      "Epoch 156: Train loss=-23120.156836, Validation loss=-23558.577344 (saved)\n",
      "Epoch 157: Train loss=-23475.760547, Validation loss=-22834.773438 \n",
      "Epoch 158: Train loss=-22373.387305, Validation loss=-23286.596875 \n",
      "Epoch 159: Train loss=-23004.437109, Validation loss=-22749.389844 \n",
      "Epoch 160: Train loss=-22979.494727, Validation loss=-23475.806250 \n",
      "Epoch 161: Train loss=-23153.431445, Validation loss=-22693.286719 \n",
      "Epoch 162: Train loss=-23143.334180, Validation loss=-23121.902344 \n",
      "Epoch 163: Train loss=-23409.770313, Validation loss=-23434.640625 \n",
      "Epoch 164: Train loss=-23495.522656, Validation loss=-23584.365625 (saved)\n",
      "Epoch 165: Train loss=-23482.948437, Validation loss=-23532.632031 \n",
      "Epoch 166: Train loss=-23203.886133, Validation loss=-22138.996875 \n",
      "Epoch 167: Train loss=-22766.272070, Validation loss=-23128.490625 \n",
      "Epoch 168: Train loss=-23390.700195, Validation loss=-23357.895313 \n",
      "Epoch 169: Train loss=-23421.346484, Validation loss=-23407.708594 \n",
      "Epoch 170: Train loss=-23574.534375, Validation loss=-23647.156250 (saved)\n",
      "Epoch 171: Train loss=-23647.188672, Validation loss=-23600.000000 \n",
      "Epoch 172: Train loss=-23550.525391, Validation loss=-23616.055469 \n",
      "Epoch 173: Train loss=-23332.577930, Validation loss=-23680.418750 (saved)\n",
      "Epoch 174: Train loss=-23583.167969, Validation loss=-23753.518750 (saved)\n",
      "Epoch 175: Train loss=-23648.020508, Validation loss=-23593.095313 \n",
      "Epoch 176: Train loss=-23552.606250, Validation loss=-23809.094531 (saved)\n",
      "Epoch 177: Train loss=-23680.740039, Validation loss=-22961.802344 \n",
      "Epoch 178: Train loss=-23167.250977, Validation loss=-23714.332031 \n",
      "Epoch 179: Train loss=-23451.492969, Validation loss=-22686.013281 \n",
      "Epoch 180: Train loss=-23004.681641, Validation loss=-23240.207813 \n",
      "Epoch 181: Train loss=-23306.222656, Validation loss=-23701.696875 \n",
      "Epoch 182: Train loss=-23673.685352, Validation loss=-23652.547656 \n",
      "Epoch 183: Train loss=-23539.425195, Validation loss=-23690.953906 \n",
      "Epoch 184: Train loss=-23614.193945, Validation loss=-23783.067188 \n",
      "Epoch 185: Train loss=-23785.964258, Validation loss=-23211.912500 \n",
      "Epoch 186: Train loss=-23548.498437, Validation loss=-23596.863281 \n",
      "Epoch 187: Train loss=-23366.183203, Validation loss=-23308.321094 \n",
      "Epoch 188: Train loss=-23148.499219, Validation loss=-22983.316406 \n",
      "Epoch 189: Train loss=-22325.575781, Validation loss=-23450.034375 \n",
      "Epoch 190: Train loss=-22788.506445, Validation loss=-22100.063281 \n",
      "Epoch 191: Train loss=-22937.196094, Validation loss=-23538.649219 \n",
      "Epoch 192: Train loss=-23069.254492, Validation loss=-23665.531250 \n",
      "Epoch 193: Train loss=-23220.776562, Validation loss=-23553.864844 \n",
      "Epoch 194: Train loss=-23581.741211, Validation loss=-23652.747656 \n",
      "Epoch 195: Train loss=-23633.379883, Validation loss=-23513.810156 \n",
      "Epoch 196: Train loss=-23446.943945, Validation loss=-23663.014062 \n",
      "Epoch 197: Train loss=-23692.530859, Validation loss=-23786.808594 \n",
      "Epoch 198: Train loss=-23560.392188, Validation loss=-23177.515625 \n",
      "Epoch 199: Train loss=-23066.469922, Validation loss=-22142.592187 \n",
      "Finished training with best train loss: -23760.304688 and validation loss: -23809.094531\n"
     ]
    }
   ],
   "source": [
    "# x_ood_test = x_ood_test.to(device)\n",
    "# Fit the model; start/stop bracket the wall-clock training time.\n",
    "start = time.time()\n",
    "model.fit(\n",
    "    train_loader,\n",
    "    valid_loader,\n",
    "    x_test=x_ood_test,\n",
    "    epochs=epochs,\n",
    "    lr=lr,\n",
    "    step_size=step_size,\n",
    "    gamma=gamma,\n",
    "    tpred=torch.tensor(tpred).to(device),\n",
    "    dataset_class=dataset_class,\n",
    "    t=t.to(device),\n",
    "    grid_train=grid.to(device),\n",
    ")\n",
    "stop = time.time()\n",
    "# print(stop-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "id": "fea6bf3c-dd2f-4877-b902-345a59828a3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save only the learned parameters (state_dict), not the pickled module object.\n",
    "torch.save(obj=model.state_dict(), f=\"./pme_e2e_var_update_ortho_nll_2_3_05_15.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "221859cb-662e-4e38-ad03-cd466b1c3e74",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model.load_state_dict(torch.load(\"./pme_e2e_var_update_ortho_crps_05_05_4_5.pt\", weights_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "e82f2228-6888-4796-8b40-f13e419d3ff7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nonlinear_projection import project_and_stats, project_and_stats_orth\n",
    "\n",
    "\n",
    "def _normalize_results(results, n_examples):\n",
    "    \"\"\"Average accumulated metrics over the dataset, in place.\n",
    "\n",
    "    Every key not ending in \"by_example\" is divided by ``n_examples``; any\n",
    "    torch.Tensor value is then converted to a (nested) Python list.\n",
    "    Returns the same dict for convenience.\n",
    "    \"\"\"\n",
    "    for key in results.keys():\n",
    "        if not key.endswith(\"by_example\"):\n",
    "            results[key] /= n_examples\n",
    "        if isinstance(results[key], torch.Tensor):\n",
    "            results[key] = results[key].tolist()\n",
    "    return results\n",
    "\n",
    "\n",
    "def test(model, test_loader, **test_params):\n",
    "    \"\"\"Evaluate ``model`` on ``test_loader``; return averaged metrics and example predictions.\n",
    "\n",
    "    Keyword options (``test_params``):\n",
    "        test_type:  \"id\" (default), \"train\" or \"ood\"; \"id\"/\"train\" pair the run with\n",
    "                    ``dataset_params``, anything else with ``ood_dataset_params``.\n",
    "        projection: \"orth\" (default, ``project_and_stats_orth``) or \"oblique\"\n",
    "                    (``project_and_stats``); only used when ``model.noneq_constraint_e2e``.\n",
    "        seed:       optional seed for picking the \"random\" example (default: unseeded).\n",
    "\n",
    "    Reads notebook globals: device, uq, is_probconserv, probconserv, dataset_class,\n",
    "    t, tpred, grid, model_name, n_x, n_t, fno_modes, fno_width, n_models,\n",
    "    dataset_params, ood_dataset_params.\n",
    "\n",
    "    When ``uq`` and ``is_probconserv`` are both set, a post-hoc ProbConserv update is\n",
    "    applied and its metrics are added under \"pc.\"-prefixed keys, together with the\n",
    "    conservation errors \"cerr_by_example\" / \"mcerr\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``projection`` is not \"orth\" or \"oblique\".\n",
    "    \"\"\"\n",
    "    test_type = test_params.get(\"test_type\", \"id\")\n",
    "    projections = {\"orth\": project_and_stats_orth, \"oblique\": project_and_stats}\n",
    "    projection = test_params.get(\"projection\", \"orth\")\n",
    "    if projection not in projections:\n",
    "        raise ValueError(f\"Unknown projection {projection!r}; expected one of {sorted(projections)}\")\n",
    "    project_fn = projections[projection]\n",
    "    rng = np.random.default_rng(test_params.get(\"seed\"))\n",
    "\n",
    "    mu = []\n",
    "    var = []\n",
    "    results = {\"loss\": 0.0}\n",
    "\n",
    "    model = model.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "\n",
    "            out = model(x)\n",
    "\n",
    "            if model.noneq_constraint_e2e:\n",
    "                # Flatten each example, project the clipped (non-negative) mean onto the\n",
    "                # nonlinear constraint, then restore the (nf, nx, nt, 1) layout.\n",
    "                _mu, _var = out\n",
    "                nf, nx, nt, _ = _mu.shape\n",
    "                u_proj, u_var = project_fn(\n",
    "                    torch.relu(_mu.view(nf, -1)),\n",
    "                    _var.view(nf, -1),\n",
    "                    x.view(nf, -1),\n",
    "                    model.full_residual,\n",
    "                    max_iter=30,\n",
    "                )\n",
    "                out = (u_proj.view(nf, nx, nt, 1), u_var.view(nf, nx, nt, 1))\n",
    "\n",
    "            results[\"loss\"] += model.loss_func(out, y).item()\n",
    "            utils.compute_all_metrics(out, y, results)\n",
    "\n",
    "            if uq:\n",
    "                mu.append(out[0].detach().cpu())\n",
    "                var.append(out[1].detach().cpu())\n",
    "            else:\n",
    "                mu.append(out.detach().cpu())\n",
    "\n",
    "    n_examples = len(test_loader.dataset)\n",
    "    _normalize_results(results, n_examples)\n",
    "\n",
    "    mu = torch.cat(mu, dim=0)\n",
    "    if uq:\n",
    "        var = torch.cat(var, dim=0)\n",
    "        std = torch.sqrt(var)\n",
    "    else:\n",
    "        var = None\n",
    "        std = None\n",
    "    x = test_loader.dataset.tensors[0]\n",
    "    y = test_loader.dataset.tensors[1]\n",
    "\n",
    "    if uq:\n",
    "        results[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, var, y).item()\n",
    "        results[\"rmsce_all\"] = utils.compute_rmsce(mu, var, y).item()\n",
    "\n",
    "        if is_probconserv:\n",
    "            # Post-hoc ProbConserv update on the CPU-collected predictions.\n",
    "            mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "            new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "                mu=mu[:, :, :, 0],\n",
    "                std=std[:, :, :, 0],\n",
    "                mass_rhs_func=mass_rhs_func,\n",
    "                t=t,\n",
    "                tpred=tpred,\n",
    "                grid_train=grid,\n",
    "                precis_g=np.inf,\n",
    "                second_deriv_alpha=None,\n",
    "            )\n",
    "            new_mu = new_mu[:, :, :, None]\n",
    "            new_std = new_std[:, :, :, None]\n",
    "            new_var = new_std**2\n",
    "\n",
    "            probconserv_results = utils.compute_all_metrics((new_mu, new_var), y, {})\n",
    "            _normalize_results(probconserv_results, n_examples)\n",
    "            probconserv_results[\"nMeRCI_all\"] = utils.compute_nMeRCI(new_mu, new_var, y).item()\n",
    "            probconserv_results[\"rmsce_all\"] = utils.compute_rmsce(new_mu, new_var, y).item()\n",
    "\n",
    "            # Conservation error: absolute gap to the mass_rhs returned by\n",
    "            # apply_constraint, summed over the last axis (one value per example).\n",
    "            cerr = (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "            new_cerr = (probconserv.get_empirical_mass_rhs(new_mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "            results[\"cerr_by_example\"] = cerr.tolist()\n",
    "            results[\"mcerr\"] = cerr.mean().item()\n",
    "            probconserv_results[\"cerr_by_example\"] = new_cerr.tolist()\n",
    "            probconserv_results[\"mcerr\"] = new_cerr.mean().item()\n",
    "\n",
    "            for key, value in probconserv_results.items():\n",
    "                results[f\"pc.{key}\"] = value\n",
    "\n",
    "    # results[\"time\"] = utils.compute_forward_time(model, x[:batch_size].to(device), repetitions=10)\n",
    "    results[\"n_params\"] = utils.compute_n_params(model)\n",
    "    results[\"n_flops\"] = utils.compute_n_flops(model_name, Np=n_x*n_t, fno_modes=fno_modes, fno_width=fno_width, n_layers=4, n_models=n_models)\n",
    "\n",
    "    dataset_params_correct_type = dataset_params if test_type in (\"id\", \"train\") else ood_dataset_params\n",
    "\n",
    "    # Pick representative examples by per-example MSE.\n",
    "    mse_by_example = torch.tensor(results[\"mse_by_example\"])\n",
    "    random_idx = rng.choice(mse_by_example.shape[0])\n",
    "    _, worst_idx = mse_by_example.max(dim=0)\n",
    "    _, best_idx = mse_by_example.min(dim=0)\n",
    "    _, median_idx = mse_by_example.median(dim=0)\n",
    "\n",
    "    for example_name, example_idx in zip([\"random\", \"worst\", \"best\", \"median\"], [random_idx, worst_idx, best_idx, median_idx]):\n",
    "        if uq:\n",
    "            results[f\"examples.{example_name}\"] = (mu[example_idx].tolist(), var[example_idx].tolist(), y[example_idx].tolist(), x[example_idx].tolist())\n",
    "            if is_probconserv:\n",
    "                results[f\"pc.examples.{example_name}\"] = (new_mu[example_idx].tolist(), new_var[example_idx].tolist(), y[example_idx].tolist(), x[example_idx].tolist())\n",
    "        else:\n",
    "            results[f\"examples.{example_name}\"] = (mu[example_idx].tolist(), None, y[example_idx].tolist(), x[example_idx].tolist())\n",
    "\n",
    "        # prefix = f\"{test_type}_{example_name}_params={dataset_params_correct_type}\"\n",
    "        # plot_and_save(prefix, example_idx, x.squeeze(-1), y.squeeze(-1), mu.squeeze(-1), std.squeeze(-1) if std is not None else None)\n",
    "\n",
    "    # utils.dict_to_file({\"test_type\": test_type, \"params\": dataset_params_correct_type, \"results\": results}, \n",
    "    #                    f\"{run_folder}/results_{test_type}_params={dataset_params_correct_type}.json\")\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "4f58dc93-0b91-4195-b51e-126ad1212b6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here\n",
      "Here\n",
      "Here\n",
      "Here\n",
      "In-domain results\n",
      "MSE: 3.73231217963621e-05\n",
      "n-MeRCI: 0.4797120690345764\n",
      "RMSCE: 0.2042180746793747\n",
      "ProbConserv Results\n",
      "MSE: 1.6498995246365667e-05\n",
      "n-MeRCI: 0.5785845518112183\n",
      "RMSCE: 0.23048627376556396\n",
      "Cerr: 0.03420382738113403\n",
      "Prob_Cerr: 2.9866583872717456e-07\n",
      "Here\n",
      "\n",
      "\n",
      "Out-of-domain results\n",
      "MSE: 4.3509078095667065e-05\n",
      "n-MeRCI: 0.507407546043396\n",
      "RMSCE: 0.20524950325489044\n",
      "ProbConserv Results\n",
      "MSE: 1.8751994939520955e-05\n",
      "n-MeRCI: 0.643882691860199\n",
      "RMSCE: 0.22795435786247253\n",
      "Cerr: 0.03335271030664444\n",
      "Prob_Cerr: 2.7469360475151916e-07\n"
     ]
    }
   ],
   "source": [
    "is_probconserv = True\n",
    "\n",
    "\n",
    "def print_results(results, header):\n",
    "    \"\"\"Print the headline metrics of a ``test`` result dict under ``header``.\n",
    "\n",
    "    ProbConserv metrics (\"pc.\"-prefixed keys and the conservation errors) are\n",
    "    printed as well when ``is_probconserv`` is set.\n",
    "    \"\"\"\n",
    "    print(header)\n",
    "    print(f\"MSE: {results['mse']}\")\n",
    "    print(f\"n-MeRCI: {results['nMeRCI_all']}\")\n",
    "    print(f\"RMSCE: {results['rmsce_all']}\")\n",
    "\n",
    "    if is_probconserv:\n",
    "        print(\"ProbConserv Results\")\n",
    "        print(f\"MSE: {results['pc.mse']}\")\n",
    "        print(f\"n-MeRCI: {results['pc.nMeRCI_all']}\")\n",
    "        print(f\"RMSCE: {results['pc.rmsce_all']}\")\n",
    "        print(f\"Cerr: {results['mcerr']}\")\n",
    "        print(f\"Prob_Cerr: {results['pc.mcerr']}\")\n",
    "\n",
    "\n",
    "# Each split is evaluated exactly once; is_train only controls whether the\n",
    "# in-domain metrics are printed (id_results is computed either way).\n",
    "train_loader_no_shuffle = torch.utils.data.DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=False)\n",
    "train_results = test(model, train_loader_no_shuffle, test_type=\"train\")\n",
    "id_results = test(model, id_test_loader, test_type=\"id\")\n",
    "\n",
    "if is_train:\n",
    "    print_results(id_results, \"In-domain results\")\n",
    "\n",
    "ood_results = test(model, ood_test_loader, test_type=\"ood\")\n",
    "\n",
    "print(\"\\n\")\n",
    "print_results(ood_results, \"Out-of-domain results\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "d3c74561-cafe-4068-b5ec-92d2c4634cc9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.73231217963621e-05"
      ]
     },
     "execution_count": 141,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id_results[\"mse\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "c29252c6-3147-4b6a-b320-ee566b03e910",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 142,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "getattr(model, \"probconserv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "b13f9495-85ed-4ba9-91fd-301557685c8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-batch OOD inference: no autograd graph is needed for the forward pass.\n",
    "with torch.no_grad():\n",
    "    out = model(x_ood_test.to(device))\n",
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "\n",
    "if model.probconserv:\n",
    "    # ProbConserv is applied to CPU tensors (as in `test`); the result is moved back\n",
    "    # to `device` so it is on the same device as `_m` in the projection below.\n",
    "    _mu, _var = out[0].cpu(), out[1].cpu()\n",
    "    _std = torch.sqrt(_var)\n",
    "    new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "        mu=_mu[:, :, :, 0],\n",
    "        std=_std[:, :, :, 0],\n",
    "        mass_rhs_func=mass_rhs_func,\n",
    "        t=t,\n",
    "        tpred=tpred,\n",
    "        grid_train=grid,\n",
    "        precis_g=np.inf,\n",
    "        second_deriv_alpha=None,\n",
    "    )\n",
    "    out = (new_mu.unsqueeze(-1).to(device), torch.square(new_std).unsqueeze(-1).to(device))\n",
    "\n",
    "# Flatten each example, project the clipped (non-negative) mean onto the constraint,\n",
    "# then restore the (nf, nx, nt, 1) layout.\n",
    "mu, var = out\n",
    "nf, nx, nt, _ = mu.shape\n",
    "\n",
    "_mu = mu.view(nf, -1)\n",
    "_var = var.view(nf, -1)\n",
    "_m = x.view(nf, -1).to(device)\n",
    "\n",
    "u_proj, u_var = project_and_stats_orth(torch.relu(_mu), _var, _m, model.full_residual, max_iter=30)\n",
    "\n",
    "out = (u_proj.view(nf, nx, nt, 1), u_var.view(nf, nx, nt, 1))\n",
    "\n",
    "mu, var = out[0].cpu(), out[1].cpu()\n",
    "\n",
    "std = torch.sqrt(var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "58602c51-77ce-4e25-ac60-b976ccd541c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [4.9121e-05],\n",
       "          [1.2157e-05],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [2.3136e-04]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [2.4280e-05],\n",
       "          [2.7961e-05],\n",
       "          ...,\n",
       "          [8.9588e-05],\n",
       "          [7.5112e-05],\n",
       "          [1.0950e-04]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0118e-06],\n",
       "          [1.0271e-06],\n",
       "          ...,\n",
       "          [1.0003e-06],\n",
       "          [1.0003e-06],\n",
       "          [1.0001e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0118e-06],\n",
       "          [1.0271e-06],\n",
       "          ...,\n",
       "          [1.0003e-06],\n",
       "          [1.0003e-06],\n",
       "          [1.0001e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]]],\n",
       "\n",
       "\n",
       "        [[[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [4.4400e-05],\n",
       "          [1.1006e-05],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.4997e-04]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [2.1644e-05],\n",
       "          [2.0844e-05],\n",
       "          ...,\n",
       "          [5.7688e-05],\n",
       "          [4.7838e-05],\n",
       "          [6.9782e-05]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0095e-06],\n",
       "          [1.0183e-06],\n",
       "          ...,\n",
       "          [1.0002e-06],\n",
       "          [1.0002e-06],\n",
       "          [1.0001e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0095e-06],\n",
       "          [1.0183e-06],\n",
       "          ...,\n",
       "          [1.0002e-06],\n",
       "          [1.0002e-06],\n",
       "          [1.0001e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]]],\n",
       "\n",
       "\n",
       "        [[[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [4.9318e-05],\n",
       "          [1.2205e-05],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [2.3740e-04]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [2.4425e-05],\n",
       "          [2.8416e-05],\n",
       "          ...,\n",
       "          [9.1994e-05],\n",
       "          [7.7168e-05],\n",
       "          [1.1246e-04]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0119e-06],\n",
       "          [1.0277e-06],\n",
       "          ...,\n",
       "          [1.0003e-06],\n",
       "          [1.0003e-06],\n",
       "          [1.0002e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0119e-06],\n",
       "          [1.0277e-06],\n",
       "          ...,\n",
       "          [1.0003e-06],\n",
       "          [1.0003e-06],\n",
       "          [1.0002e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]]],\n",
       "\n",
       "\n",
       "        ...,\n",
       "\n",
       "\n",
       "        [[[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [4.6744e-05],\n",
       "          [1.1579e-05],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.8072e-04]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [2.2846e-05],\n",
       "          [2.3789e-05],\n",
       "          ...,\n",
       "          [6.9622e-05],\n",
       "          [5.8044e-05],\n",
       "          [8.4747e-05]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0105e-06],\n",
       "          [1.0218e-06],\n",
       "          ...,\n",
       "          [1.0002e-06],\n",
       "          [1.0002e-06],\n",
       "          [1.0001e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0105e-06],\n",
       "          [1.0218e-06],\n",
       "          ...,\n",
       "          [1.0002e-06],\n",
       "          [1.0002e-06],\n",
       "          [1.0001e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]]],\n",
       "\n",
       "\n",
       "        [[[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [4.9615e-05],\n",
       "          [1.2277e-05],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [2.4751e-04]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [2.4659e-05],\n",
       "          [2.9160e-05],\n",
       "          ...,\n",
       "          [9.6034e-05],\n",
       "          [8.0618e-05],\n",
       "          [1.1742e-04]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0121e-06],\n",
       "          [1.0287e-06],\n",
       "          ...,\n",
       "          [1.0004e-06],\n",
       "          [1.0004e-06],\n",
       "          [1.0002e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0121e-06],\n",
       "          [1.0287e-06],\n",
       "          ...,\n",
       "          [1.0004e-06],\n",
       "          [1.0004e-06],\n",
       "          [1.0002e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]]],\n",
       "\n",
       "\n",
       "        [[[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [3.8940e-05],\n",
       "          [9.6654e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0462e-04]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.8916e-05],\n",
       "          [1.5682e-05],\n",
       "          ...,\n",
       "          [4.0315e-05],\n",
       "          [3.3002e-05],\n",
       "          [4.7815e-05]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0070e-06],\n",
       "          [1.0128e-06],\n",
       "          ...,\n",
       "          [1.0001e-06],\n",
       "          [1.0001e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0070e-06],\n",
       "          [1.0128e-06],\n",
       "          ...,\n",
       "          [1.0001e-06],\n",
       "          [1.0001e-06],\n",
       "          [1.0000e-06]],\n",
       "\n",
       "         [[1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          ...,\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06],\n",
       "          [1.0000e-06]]]], grad_fn=<ToCopyBackward0>)"
      ]
     },
     "execution_count": 144,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "6d7837ba-f720-43d2-b8d0-8252c41e8ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "std = torch.sqrt(var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "cb722479-afdf-4b05-a156-c8a2bbbf245a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAHHCAYAAABTMjf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpL0lEQVR4nO3dd3hT5dsH8G+SJuneuxTassu2QC1bqJYhCg5QEQoiiIIKFRREQEApArKRKrJUEBCVny8gKAUUsOwhG4GW3UKB7p2c94/aSGjaJmnSkJzv57pyXc2T55xz5yRN7jzrSARBEEBEREQkQlJLB0BERERkKUyEiIiISLSYCBEREZFoMREiIiIi0WIiRERERKLFRIiIiIhEi4kQERERiRYTISIiIhItJkJEREQkWkyEyKaEhIRg8ODBlg6DasCqVasgkUiQkpJS48cuKSnB+++/j+DgYEilUvTp06fGYyAi02AiROWUfcEcPnzY0qFYFYlEglGjRul8rCbO6c2bN/Hxxx/j+PHjetUvi6mi2/79+80WqyFmzJiBTZs2WToMLStWrMDs2bPxwgsvYPXq1RgzZoylQ7Jpa9euxfz58y0dhtF++ukn9O/fH2FhYXB0dETDhg3x3nvvISMjo8pt1Wo1Vq1ahWeeeQbBwcFwcnJC06ZN8cknn6CgoEDnNmlpaXjjjTcQFBQEe3t7hISEYOjQoSZ+VrbDztIBEJnS+fPnIZWKM7+/efMmpk6dipCQELRs2VLv7aZNm4bQ0NBy5fXq1TNhdMabMWMGXnjhhXKtLgMHDsRLL70EpVJZ4zHt3LkTQUFBmDdvXo0fW4zWrl2LU6dOYfTo0ZYOxSjDhw9HYGAgXn31VdSuXRsnT57E4sWLsXXrVhw9ehQODg4VbpuXl4chQ4bg8ccfx4gRI+Dr64ukpCRMmTIFiYmJ2LlzJyQSiab+tWvX0L59ewDAiBEjEBQUhJs3b+LgwYNmf57WiokQPbJKSkqgVquhUCj03sYSX4rWrkePHmjdurWlwzCYTCaDTCazyLFv374Nd3d3k+1PrVajqKgI9vb2JtunoQoKCqBQKETzQ6Imz/nGjRvRpUsXrbKIiAjExsZizZo1eP311yvcVqFQYN++fWjXrp2mbNiwYQgJCdEkQ9HR0ZrH3njjDdjZ2eHQoUPw8vIy+XOxReJ4x5NZ3LhxA6+99hr8/PygVCrRpEkTrFixQqtOUVERJk+ejIiICLi5ucHJyQkdO3bErl27tOqlpKRAIpFgzpw5mD9/PurWrQulUokzZ87g448/hkQiwcWLFzF48GC4u7vDzc0NQ4YMQV5entZ+Hh4jVNb9s2/fPsTFxcHHxwdOTk7o27cv7ty5o7WtWq3Gxx9/jMDAQDg6OuKJJ57AmTNnzDru6Ny5c3jhhRfg6ekJe3t7tG7dGr/88otWnXv37mHs2LFo1qwZnJ2d4erqih49euDEiROaOrt370abNm0AAEOGDNF0b61atcokcWZkZGDw4MFwc3ODu7s7YmNjcfz48XLH6NKlS7kPfAAYPHgwQkJCtMrmzJmDdu3awcvLCw4ODoiIiMDGjRu16kgkEuTm5mL16tWa51T2WlQ0RuiLL75AkyZNoFQqERgYiJEjR5brgujSpQuaNm2KM2fO4IknnoCjoyOCgoIwa9asSs9D2ft0165dOH36tCam3bt3AwByc3Px3nvvITg4GEqlEg0bNsScOXMgCEK55zVq1CisWbNGE+u2bdsqPG5ISAiefvpp/Pbbb2jZsiXs7e0RHh6On376SauePu8VoPT9IpFIsG7dOnz00UcICgqCo6MjsrKyDN7Hhg0bMHXqVAQFBcHFxQUvvPACMjMzUVhYiNGjR8PX1xfOzs4YMmQICgsLyz237777DhEREXBwcICnpydeeuklXLt2Teu12rJlC65cuaI53w++lwoLCzFlyhTUq1cPSqUSwcHBeP/998sdy9Bzbkq6/if6
9u0LADh79myl2yoUCq0kqLLtz507h19//RXjxo2Dl5cXCgoKUFxcXI3IxYEtQmSUtLQ0PP7445oPFx8fH/z6668YOnQosrKyNE3YWVlZ+Prrr/Hyyy9j2LBhyM7OxvLlyxETE4ODBw+W68JZuXIlCgoKMHz4cCiVSnh6emoe69evH0JDQxEfH4+jR4/i66+/hq+vLz777LMq43377bfh4eGBKVOmICUlBfPnz8eoUaOwfv16TZ0JEyZg1qxZ6N27N2JiYnDixAnExMRU2A+vS0FBAdLT08uV5+TklCs7ffo02rdvj6CgIIwfPx5OTk7YsGED+vTpgx9//FHzQXf58mVs2rQJL774IkJDQ5GWloYvv/wSnTt3xpkzZxAYGIjGjRtj2rRpmDx5MoYPH46OHTsCgM4P0IdlZmaWi1kikWh+TQqCgGeffRZ79+7FiBEj0LhxY/z888+IjY3V+7zosmDBAjzzzDMYMGAAioqKsG7dOrz44ovYvHkzevXqBQD49ttv8frrr6Nt27YYPnw4AKBu3boV7vPjjz/G1KlTER0djTfffBPnz5/H0qVLcejQIezbtw9yuVxT9/79++jevTuee+459OvXDxs3bsQHH3yAZs2aoUePHjr37+Pjg2+//RaffvopcnJyEB8fDwBo3LgxBEHAM888g127dmHo0KFo2bIltm/fjnHjxuHGjRvlutF27tyJDRs2YNSoUfD29i6XKD7sn3/+Qf/+/TFixAjExsZi5cqVePHFF7Ft2zY8+eSTAPR7rzxo+vTpUCgUGDt2LAoLC6FQKHDmzBmD9hEfHw8HBweMHz8eFy9exKJFiyCXyyGVSnH//n18/PHH2L9/P1atWoXQ0FBMnjxZs+2nn36KSZMmoV+/fnj99ddx584dLFq0CJ06dcKxY8fg7u6OiRMnIjMzE9evX9ecQ2dnZwClP16eeeYZ7N27F8OHD0fjxo1x8uRJzJs3DxcuXCg3tsyQc56Tk6PX/75cLoebm1uV9R6WmpoKAPD29jZ424q237FjBwDAz88P3bp1w86dOyGTyfDkk09i6dKlVb7HREsgesjKlSsFAMKhQ4cqrDN06FAhICBASE9P1yp/6aWXBDc3NyEvL08QBEEoKSkRCgsLtercv39f8PPzE1577TVNWXJysgBAcHV1FW7fvq1Vf8qUKQIArfqCIAh9+/YVvLy8tMrq1KkjxMbGlnsu0dHRglqt1pSPGTNGkMlkQkZGhiAIgpCamirY2dkJffr00drfxx9/LADQ2mdFAFR5e/CcduvWTWjWrJlQUFCgKVOr1UK7du2E+vXra8oKCgoElUqldazk5GRBqVQK06ZN05QdOnRIACCsXLmyylgfPDe6bkqlUlNv06ZNAgBh1qxZmrKSkhKhY8eO5Y7XuXNnoXPnzuWOFRsbK9SpU0errOw9UqaoqEho2rSp0LVrV61yJycnnee/LP7k5GRBEATh9u3bgkKhEJ566imt87V48WIBgLBixQqtOAEI33zzjaassLBQ8Pf3F55//vlyx3pY586dhSZNmmiVlZ2nTz75RKv8hRdeECQSiXDx4kVNGQBBKpUKp0+frvJYglD6vgYg/Pjjj5qyzMxMISAgQGjVqpWmTN/3yq5duwQAQlhYWLnXwdB9NG3aVCgqKtKUv/zyy4JEIhF69OihtY+oqCit90BKSoogk8mETz/9VKveyZMnBTs7O63yXr16lXv/CIIgfPvtt4JUKhX27NmjVZ6QkCAAEPbt26cpM/Scx8bG6vU/rev9ro+hQ4cKMplMuHDhglHbR0dHC66ursL9+/c1Ze+8844AQPDy8hK6d+8urF+/Xpg9e7bg7Ows1K1bV8jNzTXqWLaOXWNkMEEQ8OOPP6J3794QBAHp6emaW0xMDDIzM3H06FEApeM4ysb4qNVq3Lt3DyUlJWjdurWmzoOef/55+Pj46DzuiBEjtO537NgRd+/eRVZWVpUxDx8+XGtAYceOHaFSqXDlyhUA
QGJiIkpKSvDWW29pbff2229Xue8HPfvss/j999/L3caNG6dV7969e9i5cyf69euH7Oxszfm7e/cuYmJi8M8//+DGjRsASsc9lY3bUKlUuHv3LpydndGwYUOd59BQS5YsKRfvr7/+qnl869atsLOzw5tvvqkpk8lkBp+bhz04QPT+/fvIzMxEx44djX5OO3bsQFFREUaPHq01zmXYsGFwdXXFli1btOo7Ozvj1Vdf1dxXKBRo27YtLl++bNTxt27dCplMhnfeeUer/L333oMgCFrnFAA6d+6M8PBwvfcfGBioaSUEAFdXVwwaNAjHjh3TtA4Y+l6JjY0tN1DX0H0MGjRIq6UtMjISgiDgtdde06oXGRmJa9euoaSkBEDpTCq1Wo1+/fppfYb4+/ujfv365brPdfnhhx/QuHFjNGrUSGsfXbt2BYBy+zDknL///vs6/5cfvn3++ed67e9Ba9euxfLly/Hee++hfv36Bm8/Y8YM7NixAzNnztQaq1bW8uzv748tW7agX79+GDt2LJYtW4ZLly5h7dq1Bh9LDNg1Rga7c+cOMjIy8NVXX+Grr77SWef27duav1evXo3PP/8c586d0+qv1jVTSVdZmdq1a2vd9/DwAFD6Jerq6lppzJVtC0CTED08U8rT01NTVx+1atXSGrhY5vr161r3L168CEEQMGnSJEyaNEnnvm7fvo2goCCo1WosWLAAX3zxBZKTk6FSqTR1TDEYsm3btpUOlr5y5QoCAgI03RFlGjZsWK3jbt68GZ988gmOHz+uNZ7jwYTVEGWv4cNxKRQKhIWFaR4vU6tWrXLH8vDwwN9//2308QMDA+Hi4qJV3rhxY634ylT2XtelXr165eJt0KABgNKxS/7+/ga/V3TFYOg+Hv7fKusmCg4OLleuVquRmZkJLy8v/PPPPxAEocJE4MHkqiL//PMPzp49W+GPpwc/hwDDznl4eLhBiaq+9uzZg6FDhyImJgaffvqpwduvX78eH330EYYOHar14wT478dFv379tH4MvPjiixg4cCD++uuvSgdmixUTITKYWq0GALz66qsVjhNp3rw5gNKBkIMHD0afPn0wbtw4+Pr6QiaTIT4+HpcuXSq3XWXTSCuaISQ8NBDV1NuaQ9k5HDt2LGJiYnTWKUvKZsyYgUmTJuG1117D9OnT4enpCalUitGjR2v286iQSCQ6z+mDX6ZA6ZfBM888g06dOuGLL75AQEAA5HI5Vq5cWWO/Wi39nqjsvW4sQ98rumIwdB8Vnceqzq9arYZEIsGvv/6qs+7DibcuarUazZo1w9y5c3U+/nAyZsg5z8zMRH5+fpX1FAqF1ljGypw4cQLPPPMMmjZtio0bN8LOzrCv4N9//x2DBg1Cr169kJCQUO7xsvFbfn5+WuUymQxeXl6aH36kjYkQGczHxwcuLi5QqVQ6Wz8etHHjRoSFheGnn37S+jU7ZcoUc4dpkDp16gAobal58Ffj3bt3zfLhERYWBqD0V68+5/CJJ57A8uXLtcozMjK0Bkoa25JSlTp16iAxMRE5OTlaX07nz58vV9fDw0Nn19LDrSE//vgj7O3tsX37dq0lD1auXFluW32fV9lreP78ec35BUpnLiYnJ1d5nqurTp062LFjB7Kzs7Vahc6dO6cVn7HKWhEfPB8XLlwAAM0gWH3fK5UxxT70UbduXQiCgNDQUE3LVkUqeg/UrVsXJ06cQLdu3Uz+/n/33XexevXqKut17txZM2uwMpcuXUL37t3h6+uLrVu36pXoPejAgQPo27cvWrdujQ0bNuhMoiIiIgBA061epqioCOnp6RW2nIkdxwiRwWQyGZ5//nn8+OOPOHXqVLnHH5yWXvZL78Ff2QcOHEBSUpL5AzVAt27dYGdnh6VLl2qVL1682CzH8/X1RZcuXfDll1/i1q1b5R5/+Bw+3Erxww8/lPuwc3JyAgC9Vqs1RM+ePVFSUqJ1blQqFRYtWlSubt26dXHu3Dmt+E+cOIF9
+/Zp1ZPJZJBIJFotRSkpKTpXkHZyctLrOUVHR0OhUGDhwoVa52v58uXIzMzUzEQzl549e0KlUpV7z8ybNw8SiaTCmWj6unnzJn7++WfN/aysLHzzzTdo2bIl/P39Aej/XqmMKfahj+eeew4ymQxTp04tdzxBEHD37l3NfScnJ2RmZpbbR79+/XDjxg0sW7as3GP5+fnIzc01Oj5TjhFKTU3FU089BalUiu3bt1eakFy6dKlca/nZs2fRq1cvhISEYPPmzRW2bHXp0gW+vr5Ys2aN1oy3VatWQaVSaWYXkja2CFGFVqxYoXOdjXfffRczZ87Erl27EBkZiWHDhiE8PBz37t3D0aNHsWPHDty7dw8A8PTTT+Onn35C37590atXLyQnJyMhIQHh4eE6p5Rbip+fH9599118/vnneOaZZ9C9e3ecOHECv/76K7y9vc3S2rJkyRJ06NABzZo1w7BhwxAWFoa0tDQkJSXh+vXrmnVbnn76aUybNg1DhgxBu3btcPLkSaxZs0ar1QMoTULc3d2RkJAAFxcXODk5ITIysspxEb/++qum1eJB7dq1Q1hYGHr37o327dtj/PjxSElJ0axfo+uL6bXXXsPcuXMRExODoUOH4vbt20hISECTJk20BrX36tULc+fORffu3fHKK6/g9u3bWLJkCerVq1dujE5ERAR27NiBuXPnIjAwEKGhoYiMjCx3bB8fH0yYMAFTp05F9+7d8cwzz+D8+fP44osv0KZNG62B0ebQu3dvPPHEE5g4cSJSUlLQokUL/Pbbb/jf//6H0aNHVzrtXx8NGjTA0KFDcejQIfj5+WHFihVIS0vTakXT971SGVPsQx9169bFJ598ggkTJiAlJQV9+vSBi4sLkpOT8fPPP2P48OEYO3YsgNL3wPr16xEXF4c2bdrA2dkZvXv3xsCBA7FhwwaMGDECu3btQvv27aFSqXDu3Dls2LAB27dvN3qxUFOOEerevTsuX76M999/H3v37sXevXs1j/n5+WklKN26dQMAzfpY2dnZiImJwf379zFu3Lhyg/7r1q2LqKgoAKUD3WfPno3Y2Fh06tQJAwcOxNWrV7FgwQJ07NgRzz33nEmej82p2UlqZA0qm1YNQLh27ZogCIKQlpYmjBw5UggODhbkcrng7+8vdOvWTfjqq680+1Kr1cKMGTOEOnXqCEqlUmjVqpWwefPmctOpy6bPz549u1w8ZdPn79y5ozPOsunTglDx9PmHlwIom/q7a9cuTVlJSYkwadIkwd/fX3BwcBC6du0qnD17VvDy8hJGjBhR5XkDIIwcOVLnYxXFcenSJWHQoEGCv7+/IJfLhaCgIOHpp58WNm7cqKlTUFAgvPfee0JAQIDg4OAgtG/fXkhKStI5Vf1///ufEB4eLtjZ2VU5lb6q1/nBbe/evSsMHDhQcHV1Fdzc3ISBAwcKx44d03mM7777TggLCxMUCoXQsmVLYfv27Tqnzy9fvlyoX7++oFQqhUaNGgkrV67UvNYPOnfunNCpUyfBwcFBaykDXa+/IJROl2/UqJEgl8sFPz8/4c0339SaYiwIuqe/C4Luaf66VLR9dna2MGbMGCEwMFCQy+VC/fr1hdmzZ2st3SAIlb9XdKlTp47Qq1cvYfv27ULz5s015+yHH37Qqqfve6Xs/f/w9qbYR0Xv9Yr+j3/88UehQ4cOgpOTk+Dk5CQ0atRIGDlypHD+/HlNnZycHOGVV14R3N3dBQBar1FRUZHw2WefCU2aNBGUSqXg4eEhRERECFOnThUyMzM19Qw956ZU2f/Zw//DderU0fnZWNFN19IS33//vdCiRQtBqVQKfn5+wqhRo4SsrCzzPkkrJhEEC40WJbICGRkZ8PDwwCeffIKJEydaOpxHSkpKCkJDQ7Fy5UqzrbxNpUJCQtC0aVNs3rzZ0qEQ2RyOESL6l64ZImVXvNa1RD4REVk/jhEi+tf69euxatUq9OzZE87Ozti7dy++//57PPXUU5qr
ORMRkW1hIkT0r+bNm8POzg6zZs1CVlaWZgD1J598YunQiIjITDhGiIiIiESLY4SIiIhItJgIERERkWhxjFAV1Go1bt68CRcXF7NdwoCIiIhMSxAEZGdnIzAwUOsitA9jIlSFmzdvlrtwHxEREVmHa9euoVatWhU+zkSoCmUXT7x27RpcXV0tHA0RERHpIysrC8HBwVoXQdaFiVAVyrrDXF1dmQgRERFZmaqGtXCwNBEREYkWEyEiIiISLSZCREREJFocI0REVketVqOoqMjSYRCRBcnlcshksmrvh4kQEVmVoqIiJCcnQ61WWzoUIrIwd3d3+Pv7V2udPyZCRGQ1BEHArVu3IJPJEBwcXOkiaURkuwRBQF5eHm7fvg0ACAgIMHpfTISIyGqUlJQgLy8PgYGBcHR0tHQ4RGRBDg4OAIDbt2/D19fX6G4y/pwiIquhUqkAAAqFwsKRENGjoOwHUXFxsdH7YCJERFaH1/0jIsA0nwVMhIiIiEi0rCoR+vPPP9G7d28EBgZCIpFg06ZNVW6ze/duPPbYY1AqlahXrx5WrVpl9jiJiIjIOlhVIpSbm4sWLVpgyZIletVPTk5Gr1698MQTT+D48eMYPXo0Xn/9dWzfvt3MkRIREZE1sKpEqEePHvjkk0/Qt29fveonJCQgNDQUn3/+ORo3boxRo0bhhRdewLx588wcKRGRZXXp0gWjR48u97cl47AG1hLv3bt34evri5SUFEuHYjYvvfQSPv/8c7Mfx6oSIUMlJSUhOjpaqywmJgZJSUkVblNYWIisrCytGxGRNfvpp58wffp0vetbSzLwqFu6dCmaN28OV1dXuLq6IioqCr/++qtJ9v3pp5/i2WefRUhIiEn2V5nBgwdDIpGUu128eFHr8ZkzZ2ptt2nTJp2Dma9du4bXXnsNgYGBUCgUqFOnDt59913cvXtXq95HH32ETz/9FJmZmeZ7crDxRCg1NRV+fn5aZX5+fsjKykJ+fr7ObeLj4+Hm5qa5BQcH10SoZnc/twj3cnlJAiJrYcpLiHh6esLFxcVk+6P/dOnSpcKxp7Vq1cLMmTNx5MgRHD58GF27dsWzzz6L06dPV+uYeXl5WL58OYYOHVqt/eij7H3YvXt33Lp1S+sWGhqqqWdvb4/PPvsM9+/fr3R/ly9fRuvWrfHPP//g+++/x8WLF5GQkIDExERERUXh3r17mrpNmzZF3bp18d1335nnyf3LphMhY0yYMAGZmZma27Vr1ywdkkkUq9S4cT8fNzPyIQiCpcMhEp0uXbpg1KhRGDVqFNzc3ODt7Y1JkyZp/h/LHh89ejS8vb0RExMDoPS6avHx8QgNDYWDgwNatGiBjRs3au07NzcXgwYNgrOzMwICAsp1JzzcwqNWqzFr1izUq1cPSqUStWvXxqeffgqg9Nf9H3/8gQULFmh++aekpJgkDl327t0LuVyOgoICTVlKSgokEgmuXLmic5tt27ahQ4cOcHd3h5eXF55++mlcunSp3HN+55138P7778PT0xP+/v74+OOPqx2vIXr37o2ePXuifv36aNCgAT799FM4Oztj//79WvX279+Pbt26wcvLq1yri65eia1bt0KpVOLxxx/XlH311VcIDAwsd+mZZ599Fq+99hoA/c+brvehUqmEv7+/1u3BBQyjo6Ph7++P+Pj4Ss/JyJEjoVAo8Ntvv6Fz586oXbs2evTogR07duDGjRuYOHFiuXO4bt26SvdZXTadCPn7+yMtLU2rLC0tDa6urpoVKR+mVCo1zZhlN1tQoi79sL2bU4Tk9FyUqHidJqKatnr1atjZ2eHgwYNYsGAB5s6di6+//lrrcYVCgX379iEhIQFAaSv1N998g4SEBJw+fRpjxozBq6++ij/++EOz3bhx4/DHH3/gf//7H3777Tfs3r0bR48erTCOCRMmYObMmZg0aRLOnDmDtWvXalrPFyxYgKioKAwbNkzzyz84ONgscQDA8ePH0bhxY9jb22vKjh07
Bg8PD9SpU0fnNrm5uYiLi8Phw4eRmJgIqVSKvn37lksCVq9eDScnJxw4cACzZs3CtGnT8Pvvv1crXmOpVCqsW7cOubm5iIqK0pSfOHECXbp0QatWrbBnzx5s27YNnp6e6NatG9avX6/zO2jPnj2IiIjQKnvxxRdx9+5d7Nq1S1N27949bNu2DQMGDABg2Hl7+H1YFZlMhhkzZmDRokW4fv26zjr37t3D9u3b8dZbb5X7Dvb398eAAQOwfv16rR/rbdu2xcGDB1FYWKhXHMaw6UtsREVFYevWrVplv//+u9abUCxU6v/eWLmFKly6k4swHyfIZTadC5NIfL3nMr7ek1xlvaZBrvg6to1W2eurD+HUjarHAr7eMRSvdwwzOkYACA4Oxrx58yCRSNCwYUOcPHkS8+bNw7BhwwAA9evXx6xZszT1CwsLMWPGDOzYsUPzuRUWFoa9e/fiyy+/ROfOnZGTk4Ply5fju+++Q7du3QCUfpHVqlVLZwzZ2dlYsGABFi9ejNjYWABA3bp10aFDBwCAm5sbFAoFHB0d4e/vb7Y4ypw4cQKtWrXSKjt+/DhatGhR4TbPP/+81v0VK1bAx8cHZ86cQdOmTTXlzZs3x5QpUwCUntvFixcjMTERTz75pNHxzpgxAzNmzNDcz8/Px/79+zFq1ChN2ZkzZ1C7dm0AwMmTJxEVFYWCggI4Ozvj559/Rnh4uKbuO++8g+eeew5z5swBAISHh+Pll1/GkSNH0K9fP50xXLlyBYGBgVplHh4e6NGjB9auXat5Phs3boS3tzeeeOIJg87bw+9DANi8eTOcnZ0193v06IEffvhBq07fvn3RsmVLTJkyBcuXLy8X9z///ANBENC4cWOdz6tx48a4f/8+7ty5A19fXwBAYGAgioqKkJqaWmFiXF1WlQjl5ORoBmcBpdPjjx8/Dk9PT9SuXRsTJkzAjRs38M033wAARowYgcWLF+P999/Ha6+9hp07d2LDhg3YsmWLpZ6CxRQ/1AJUVKJGYYmaiRDZhOyCEqRmFVRZL8DdvlzZ3dwivbbNLigxKrYHPf7441qDR6OiovD5559rLh3y8K/8ixcvIi8vD08++aRWeVFRkSZ5uHTpEoqKihAZGal53NPTEw0bNtQZw9mzZ1FYWKj5stSHOeIoc/z4cbzyyitaZceOHUPLli0r3Oaff/7B5MmTceDAAaSnp2taNK5evVouEXpQQECA5iKdxsY7YsQIrQRlwIABeP755/Hcc89pyh5MUho2bIjjx48jMzMTGzduRGxsLP744w+Eh4cjLS0Ne/fu1WpVAwAnJ6dKV0zOz8/XakF7MJZhw4bhiy++gFKpxJo1a/DSSy9pLk6s73l7+H0IAE888QSWLl2qFaMun332Gbp27YqxY8dWGL8hwzPKWo7y8vL03sZQVpUIHT58WJPZAkBcXBwAIDY2FqtWrcKtW7dw9epVzeOhoaHYsmULxowZgwULFqBWrVr4+uuvNX2eYvJgi1BlZUTWyMXeDv6u5b8YHublVP4aZV5OCr22dbE3/8flw18uOTk5AIAtW7YgKChI6zGlUmnUMSoaFlAZc8QBlHYXnTp1qlyL0NGjR8u1Xjyod+/eqFOnDpYtW6YZF9O0adNyA8zlcrnWfYlEUq4byFCenp7w9PTU3HdwcICvry/q1auns75CodA8FhERgUOHDmHBggX48ssvceTIEajV6nKtX0eOHEHr1q0rjMHb21vnoOTevXtDEARs2bIFbdq0wZ49e7SWi9H3vOlKcpycnCp8jg/q1KkTYmJiMGHCBAwePFjrsXr16kEikeDs2bM6l8E5e/YsPDw84OPjoykrGzz9YJmpWVUi1KVLl0ozSV0j97t06YJjx46ZMSrrUKwqf97UTITIRrzeMczobquHu8rM6cCBA1r39+/fj/r161d41ezw8HAolUpcvXoVnTt31lmnbt26kMvlOHDggKY75v79+7hw4YLOberXrw8HBwckJibi9ddf
17lPhUKhaaUyVxwAcP78eRQUFGi1oCQlJeHGjRsVtgjdvXsX58+fx7Jly9CxY0cApQOuDWVMvKagVqs1413KkrLc3FzNrL6///4bf/75Jz755JMK99GqVSudM6ns7e3x3HPPYc2aNbh48SIaNmyIxx57DIDpzps+Zs6ciZYtW5ZrXfPy8sKTTz6JL774AmPGjNFKylNTU7FmzRoMGjRIqzXs1KlTqFWrFry9vc0SK2BliRAZT2eLEGePEdWoq1evIi4uDm+88QaOHj2KRYsWVTpTycXFBWPHjsWYMWOgVqvRoUMHZGZmYt++fXB1dUVsbCycnZ0xdOhQjBs3Dl5eXvD19cXEiRM13SEPs7e3xwcffID3338fCoUC7du3x507d3D69GnNdOyQkBAcOHAAKSkpcHZ2hqenp8njAEq7xQBg0aJFeOedd3Dx4kW88847ACpePsDDwwNeXl746quvEBAQgKtXr2L8+PH6nH4txsQLlLaOlbWQAdDMaEpNTdWU+fj4QCaTYcKECejRowdq166N7OxsrF27Frt379Zc3SAyMhIODg4YN24cJk6ciEuXLmHkyJEYOXKk1oywh5W1uNy/fx8eHh5ajw0YMABPP/00Tp8+jVdffVVTbqrzpo9mzZphwIABWLhwYbnHFi9ejHbt2iEmJgaffPIJQkNDcfr0aYwbNw5BQUGa2Ytl9uzZg6eeesoscZZhIiQCFc0QY4sQUc0aNGgQ8vPz0bZtW8hkMrz77rsYPnx4pdtMnz4dPj4+iI+Px+XLl+Hu7o7HHnsMH374oabO7NmzkZOTg969e8PFxQXvvfdepYvQTZo0CXZ2dpg8eTJu3ryJgIAAjBgxQvP42LFjERsbi/DwcOTn5yM5OdkscRw/fhwxMTG4fPkymjVrhvDwcEydOhVvvvkmFi5ciG+//bbcNlKpFOvWrcM777yDpk2bomHDhli4cCG6dOlS6XnUxdB4AWDOnDmYOnVqpXWSk5MREhKC27dvY9CgQbh16xbc3NzQvHlzbN++XTPWysfHBxs2bMB7772H5s2bo3bt2hg1apRm2EdFmjVrhsceewwbNmzAG2+8ofVY165d4enpifPnz2uNvTLledPHtGnTsH79+nLl9evXx+HDhzFlyhT069cP9+7dg7+/P/r06YMpU6ZodTsWFBRg06ZN2LZtm1liLCMRuKhMpbKysuDm5obMzEyrnUpfUKzCP2k55cq9XRQIcDN8vACRpRQUFCA5ORmhoaE6B4s+yrp06YKWLVti/vz5lg7lkRETE4M2bdpU2g1Eum3ZsgXjxo3DqVOnqmzFslZLly7Fzz//jN9++63COpV9Juj7/W2bZ4+0lFTQ8sPB0kRkSSdOnECzZs0sHYZV6tWrF4YPH44bN25YOhSzkcvlWLRokdmPw64xEVDpGCgNANWcPEFEZLTU1FSkpaUxEaoGW78eXEWD+U2NiZAIlFSQ8VRUTkSmt3v3bkuH8Ejx9/fn5X7okcCuMRGoqGtMzQ8hIiISOSZCIlDxGKEaDoSIiOgRw0RIBCoaI8TB0kREJHZMhESguIKxQOwaIyIisWMiJAIVtfwIAhdVJCIicWMiJAIlFXSNAbzMBhERiRsTIRsnCEKlY4E4ToiIiMSMiZCNq2jGWBmOEyIiIjFjImTjqmrxYYsQERGJGRMhG1dcxWJBXFyayPy6dOli85dDILJWTIRsXJUtQuwaI7I4QRBQUlJi6TCIRImJkI2raowQu8aIzGvw4MH4448/sGDBAkgkEkgkEqxatQoSiQS//vorIiIioFQqsXfvXgwePBh9+vTR2n706NHo0qWL5r5arUZ8fDxCQ0Ph4OCAFi1aYOPGjTX7pIhsCC+6auMqmzoPcLA0WTdBEJBXnGeRYzvKHSGRSKqst2DBAly4cAFNmzbFtGnTAACnT58GAIwfPx5z5sxBWFgYPDw89DpufHw8vvvuOyQkJKB+/fr4888/8eqrr8LHxwedO3c2/gkRiRQT
IRtX1RXm2SJE1iyvOA/O8c4WOXbOhBw4KZyqrOfm5gaFQgFHR0f4+/sDAM6dOwcAmDZtGp588km9j1lYWIgZM2Zgx44diIqKAgCEhYVh7969+PLLL5kIERmBiZCN46wxokdX69atDap/8eJF5OXllUueioqK0KpVK1OGRiQaTIRsXDG7xsiGOcodkTMhx2LHri4nJ+0WJalUCuGh/8ni4mLN3zk5pc91y5YtCAoK0qqnVCqrHQ+RGDERsnFsESJbJpFI9OqesjSFQgGVSlVlPR8fH5w6dUqr7Pjx45DL5QCA8PBwKJVKXL16ld1gRCbCRMjGVTVGiC1CROYXEhKCAwcOICUlBc7OzlBX8H/ZtWtXzJ49G9988w2ioqLw3Xff4dSpU5puLxcXF4wdOxZjxoyBWq1Ghw4dkJmZiX379sHV1RWxsbE1+bSIbAKnz9swtVqocsHEKtZbJCITGDt2LGQyGcLDw+Hj44OrV6/qrBcTE4NJkybh/fffR5s2bZCdnY1BgwZp1Zk+fTomTZqE+Ph4NG7cGN27d8eWLVsQGhpaE0+FyOZIhIc7pElLVlYW3NzckJmZCVdXV0uHY5CiEjXOp2ZXWkcmlSA80LqeF4lXQUEBkpOTERoaCnt7e0uHQ0QWVtlngr7f32wRsmH6jP9h1xgREYkZEyEbVqzHhcQEobQLjYiISIyYCNkwVRVT5zX12CpEREQixUTIhlV1nbEynEJPRERixUTIhlU1db4MxwmRteEcDyICTPNZwETIhlV1wdUybBEiayGTyQCUXlKCiCgvr/Siy2WLjhqDCyraMH0THD0bjogszs7ODo6Ojrhz5w7kcjmkUv6WIxIjQRCQl5eH27dvw93dXfMjyRhMhGyYvl1jHCxN1kIikSAgIADJycm4cuWKpcMhIgtzd3eHv79/tfZhdYnQkiVLMHv2bKSmpqJFixZYtGgR2rZtW2H9+fPnY+nSpbh69Sq8vb3xwgsvID4+XhSLsXGwNNkihUKB+vXrs3uMSOTkcnm1WoLKWFUitH79esTFxSEhIQGRkZGYP38+YmJicP78efj6+parv3btWowfPx4rVqxAu3btcOHCBQwePBgSiQRz5861wDOoWfqOEeJgabI2UqlUFD9miMj8rKqDfe7cuRg2bBiGDBmC8PBwJCQkwNHREStWrNBZ/6+//kL79u3xyiuvICQkBE899RRefvllHDx4sIYjr3kqtQB98xu2CBERkVhZTSJUVFSEI0eOIDo6WlMmlUoRHR2NpKQkndu0a9cOR44c0SQ+ly9fxtatW9GzZ88aidmS9B0fBDARIiIi8bKarrH09HSoVCr4+flplfv5+eHcuXM6t3nllVeQnp6ODh06QBAElJSUYMSIEfjwww8rPE5hYSEKCws197OyskzzBGqYIckNu8aIiEisrKZFyBi7d+/GjBkz8MUXX+Do0aP46aefsGXLFkyfPr3CbeLj4+Hm5qa5BQcH12DEplOs5/gggC1CREQkXlbTIuTt7Q2ZTIa0tDSt8rS0tAqnzk2aNAkDBw7E66+/DgBo1qwZcnNzMXz4cEycOFHnGiQTJkxAXFyc5n5WVpZVJkNsESIiIqqa1bQIKRQKREREIDExUVOmVquRmJiIqKgondvk5eWVS3bKptpVtCy3UqmEq6ur1s0aGTZGyIyBEBERPcKspkUIAOLi4hAbG4vWrVujbdu2mD9/PnJzczFkyBAAwKBBgxAUFIT4+HgAQO/evTF37ly0atUKkZGRuHjxIiZNmoTevXubZO2BR5m+U+cBw5ImIiIiW2JViVD//v1x584dTJ48GampqWjZsiW2bdumGUB99epVrRagjz76CBKJBB999BFu3LgBHx8f9O7dG59++qmlnkKNMaRrjD1jREQkVhKBl3GuVFZWFtzc3JCZmWlV3WTJ6bnIKSjRu36TQFdIpRIzRkRERFRz9P3+tpoxQmSYEgMH/vB6Y0REJEZMhGyUvtcZK8Mp9EREJEZM
hGyQIAgGDZYGmAgREZE4MRGyQYa2BgHsGiMiInFiImSDio1YGEjNFiEiIhIhJkI2yJDLa5Rh1xgREYkREyEbZOiMMYBdY0REJE5MhGyQMS1CXFyaiIjEiImQDTJmjBBbhIiISIyYCNkgY2aNcbA0ERGJERMhG2RUixATISIiEiEmQjaIXWNERET6YSJkY9RqwaiBz+waIyIiMWIiZGOKjGgNAtgiRERE4sREyMYYM1Aa4BghIiISJyZCNsaYxRQBriNERETixETIxhjbNQawVYiIiMSHiZCNKTFiVekyTISIiEhsmAjZmOokQmoOmCYiIpFhImRj2DVGRESkPyZCNqakGqOeOYWeiIjEhomQjalW1xhbhIiISGSYCNmQEpUa1WnUYdcYERGJDRMhG1JcjdYggF1jREQkPkyEbEhxNVdF5KKKREQkNkyEbEh1xgcBbBEiIiLxYSJkQ4qrMXUe4GBpIiISHyZCNqS6iRAHSxMRkdgwEbIh7BojIiIyDBMhG8KuMSIiIsMwEbIhnD5PRERkGCZCNkIQhGqP8eEYISIiEhsmQjaiOhdbLcMGISIiEhsmQjaiugOlgdJEiOOEiIhITKwuEVqyZAlCQkJgb2+PyMhIHDx4sNL6GRkZGDlyJAICAqBUKtGgQQNs3bq1hqKtOaZIhACOEyIiInGxs3QAhli/fj3i4uKQkJCAyMhIzJ8/HzExMTh//jx8fX3L1S8qKsKTTz4JX19fbNy4EUFBQbhy5Qrc3d1rPngzM0XXGAComQgREZGISATBer75IiMj0aZNGyxevBgAoFarERwcjLfffhvjx48vVz8hIQGzZ8/GuXPnIJfLjTpmVlYW3NzckJmZCVdX12rFb063MvORnl1U7f3U83WGg0JmgoiIiIgsR9/vb6vpGisqKsKRI0cQHR2tKZNKpYiOjkZSUpLObX755RdERUVh5MiR8PPzQ9OmTTFjxgyoVKqaCrvGsGuMiIjIcFbTNZaeng6VSgU/Pz+tcj8/P5w7d07nNpcvX8bOnTsxYMAAbN26FRcvXsRbb72F4uJiTJkyRec2hYWFKCws1NzPysoy3ZMwI1N1jXEKPRERiYnVtAgZQ61Ww9fXF1999RUiIiLQv39/TJw4EQkJCRVuEx8fDzc3N80tODi4BiM2nqlahDhrjIiIxMRqEiFvb2/IZDKkpaVplaelpcHf31/nNgEBAWjQoAFksv/GvDRu3BipqakoKtI9nmbChAnIzMzU3K5du2a6J2FG1b28RhkOliYiIjGxmkRIoVAgIiICiYmJmjK1Wo3ExERERUXp3KZ9+/a4ePEi1Or/koQLFy4gICAACoVC5zZKpRKurq5at0edSi2YbDFEjhEiIiIxsZpECADi4uKwbNkyrF69GmfPnsWbb76J3NxcDBkyBAAwaNAgTJgwQVP/zTffxL179/Duu+/iwoUL2LJlC2bMmIGRI0da6imYhalagwBAbbpdERERPfKsZrA0APTv3x937tzB5MmTkZqaipYtW2Lbtm2aAdRXr16FVPpfbhccHIzt27djzJgxaN68OYKCgvDuu+/igw8+sNRTMAtTJkJsESIiIjGxqnWELMEa1hG6n1uE6/fzTbIvd0c5gj0dTbIvIiIiS7G5dYSoYoUlJuwaY15MREQiwkTIBmQVFJtsX1xHiIiIxISJkJUrLFGhsJgtQkRERMZgImTlsvJLTLo/NggREZGYMBGyctkm7BYD2DVGRETiwkTIipWo1MgrMu0FZJkIERGRmDARsmI5hSUmW1G6jCAAXFGBiIjEgomQFTP1+KAybBQiIiKxYCJkpQRBQHahaccHlWH3GBERiQUTISuVU1hituuCcQo9ERGJBRMhC7mffx/z98+H2shsJqvAPN1iAFuEiIhIPJgIWUCRqggdV3bEmO1j0GLu2ygoNnzml6mnzT+ILUJERCQWTIQsQCFToIFTXwDAqdwvMGjtIoO2zy9SobjEfMmKubrciIiIHjVMhCwk
PuZ9uKueAQBsTB6PxXu36b2tOVuDAEDFFiEiIhIJJkIW0tDfBfNiPoeDqjUESSHGJL6Cw9cu6LWtKS+yqgvHCBERkVgwEbKg2HZ18WLoHMjVISjBfXRb3RP38jIq3SYzrxj5Rebtu+KCikREJBZMhCxIIpFgfr8oNJZ/CpnggSzVJUR9+SyKSnTPCMvIK8LVe3lmj4tdY0REJBZMhCzMw0mBhJdj4Fs4GRJBiQtZf6LfutEoLNGeSXYvtwjX7uXXSEzsGiMiIrFgIvQIiKrrhXc6xcCr+G0AwP8uLcHHv61AXlFpy9DdnELcuF8zSRDAWWNERCQeTIQeEeNiGiEm9AW4FPcBAMw/Ohrbzx/C9ft5uJlRUKOxsGuMiIjEgonQI0ImlWDJK4+hf4MP8JhfRxSo8vDu7wNw5d6dGo+FCyoSEZFYMBF6hLg5yjH7hccwL3oVAp2DcS07GR/sfh0qteErT1eHmmOEiIhIJJgIPWLcHRUI9fTH/Og1sJc5YN/1HXh7y4coVtXcwB12jRERkVgwEXoEBbo7oLlfS0xsNx8AsDftK4ze9G25mWTmwlljREQkFkyEHkEyqQTBng5o4t4TrqoYQCJg3/0p+PCXPcgvMn8yxAYhIiISCyZCjyhHhR3a1/PG/O4LoBRCoJZkYE/6ZEz+3wnkFupecNFUBIHjhIiISByYCD3CfFyU6FgvCDO7rIBEsEeB7G/sv7scH206hYJi87YMcZwQERGJAROhR5hEIoGPsxLd6j2Gdx6bCQDItPsep+7+hcSzaWY9NscJERGRGDAResS52NtBKgWGPjYYXWr1AyRqpCvmYPc/l816XDYIERGRGDAResRJpRK42ssBADO7zoMDakMluYe9d2ciI6/IbMdl1xgREYkBEyEr4OZYmgg5yp3wXEg8IMiQK9uLLw+tNtsx2TVGRERiwETICrgo7SCTSgAAzzXthOaurwEAfrn6CVJzb5jlmJw1RkREYsBEyApIJBJNq1B9PxeseiEeTX0ikF2UiSl/joJaMP2q07zeGBERiQETISvh5iDX/G0ntcOnnRKglNkj6eYubDi73OTH4xghIiISAyZCVsJJIYOdTKK5H+peH2PaTAUAzD04GckZ/5j0eOqau7QZERGRxRicCJ09exZTpkxB165dUbduXQQEBKB58+aIjY3F2rVrUVhYaI44NZYsWYKQkBDY29sjMjISBw8e1Gu7devWQSKRoE+fPmaNz1wkEgncHf9rFSooVqGO/XMItG+LAlU+Pt77jkm7yNgiREREYqB3InT06FFER0ejVatW2Lt3LyIjIzF69GhMnz4dr776KgRBwMSJExEYGIjPPvvMLAnR+vXrERcXhylTpuDo0aNo0aIFYmJicPv27Uq3S0lJwdixY9GxY0eTx1STHuweW743GTN/PQ9JxggoZY44lpaEny98Z7JjcbA0ERGJgUQQ9PvpHxoainHjxuGVV16Bu7t7hfWSkpKwYMECNG/eHB9++KGp4gQAREZGok2bNli8eDEAQK1WIzg4GG+//TbGjx+vcxuVSoVOnTrhtddew549e5CRkYFNmzbpfcysrCy4ubkhMzMTrq6upnga1XI+NRtFJWocvnIPU//vDADAN+B3HMpYABeFG3554RC8HHyrfRxXBzvU8XKq9n6IiIgsQd/vbzt9d3jhwgXI5fIq60VFRSEqKgrFxcX67lovRUVFOHLkCCZMmKApk0qliI6ORlJSUoXbTZs2Db6+vhg6dCj27NlT5XEKCwu1WrOysrKqF7iJuTvKcTurEC1qucNRIUNekQo5955EI5/dOHf3BGYfmIiZXZZV+zhcR4iIiMRA764xfZKg6tSvSnp6OlQqFfz8/LTK/fz8kJqaqnObvXv3Yvny5Vi2TP/EID4+Hm5ubppbcHBwteI2tbLuMblMirYhngCAvELgpbrTIZVIsfXSD9h3PbHax+H0eSIiEgODB0unp6dj1qxZ6Nu3r6b1p2/fvpg9ezbu
3LljjhiNkp2djYEDB2LZsmXw9vbWe7sJEyYgMzNTc7t27ZoZozScvVwGyb+Tx6LqemnKb6YH4uXw4QCAT/6KQ35JXrWOo+KsMSIiEgGDEqFDhw6hQYMGWLhwIdzc3NCpUyd06tQJbm5uWLhwIRo1aoTDhw+bJVBvb2/IZDKkpWlfdT0tLQ3+/v7l6l+6dAkpKSno3bs37OzsYGdnh2+++Qa//PIL7OzscOnSJZ3HUSqVcHV11bo9auSy0pftsdoeUNiV/n0g+S5GPvYh/JyCcCP7Cr46Nrtax2CLEBERiYHeY4QA4O2338aLL76IhIQESCQSrccEQcCIESPw9ttvVzpmx1gKhQIRERFITEzUTIFXq9VITEzEqFGjytVv1KgRTp48qVX20UcfITs7GwsWLHjkurwMIZdJUFRS2jrUPMgNh6/cR0ZeMe7n2OHDqNl4d8crWH1yEXrXfwlh7g2NOgbHCBERkRgY1CJ04sQJjBkzplwSBJSuczNmzBgcP37cVLGVExcXh2XLlmH16tU4e/Ys3nzzTeTm5mLIkCEAgEGDBmkGU9vb26Np06ZaN3d3d7i4uKBp06ZQKBRmi9PcylqEAKBZkJvm71M3M/FEnZ7oXLs7SoQSzD042ehjCEJpcktERGTLDEqE/P39K13A8ODBg+UGM5tS//79MWfOHEyePBktW7bE8ePHsW3bNs0xr169ilu3bpnt+I+KBxOhpg8mQjcyAQBxbabDTmKHP69tx/6bfxh9HDYKERGRrdN7HSGgdFXn9957D2+88Qa6deumSUDS0tKQmJiIZcuWYc6cOXjrrbfMFnBNe9TWEQKA9JxC3MooAFDahTX1/06jgZ8LWtV2R5PA0sRoZtL7WHvmKzT0bIp1z/4BmVRm8HEa+rtoxiARERFZE32/vw1KhIDS1Z3nzZuHI0eOQKVSAQBkMhkiIiIQFxeHfv36VS/yR8yjmAhl5hXj6r3KZ4VlFNxDrx9aIbsoE1M7LkLfBgMNPk59P2fYyw1PoIiIiCzNbIlQmeLiYqSnpwMondFl6nWDHhWPYiKUV1SCS7dzq6y3+uQifH5wEnwc/fF/LxyGo9zZoOOE+TjBSWnQeHoiIqJHgr7f30b3e8jlcgQEBCAgIMBmk6BHlZ1Uv5ft5fDhqOUSgjt5qVh5cqHBx+GFV4mIyNaZdADIpUuX0LVrV1PuknSQy8rP2hMEAdfu5eH3M6ma2V4KmRJj2kwFAKz+exFSc28YdBwTXsyeiIjokWTSRCgnJwd//GH8LCXSj0Qigd1DydCs7efx1tqjWLjzIm5lFmjKo0OeQSu/KBSo8rH0aLxBx2GLEBER2TqDBoAsXFh598qNG4a1OJDx5DIJSlT/JSph3k7Ye7F0zNbJG5kIdHcAUJo0xbWdhoH/9yR++ed7DG85DkEudfQ6BhdVJCIiW2dQIjR69GgEBARUuBhhUVGRSYKiqpWOE/qv76rZQ+sJxTT577IjLXzbICrwCSTd3IUVf8/HpPbz9DoGL7NBRES2zqCusTp16mDevHlITk7WeduyZYu54qSHyB9a36eerzOU/5aduplZblXo4a3GAQA2XVij91ghJkJERGTrDEqEIiIicOTIkQofl0gkvCxDDZFLtccI2cmkaBxQOj0wPacIqVkFWo9H+LdDhH87FKuLsOrvBXodg11jRERk6wxKhKZNm4YXX3yxwsfDw8ORnJxc7aCoag9eZqOMrsttPOiNlu8DAH48/w3S89KqPIaas8aIiMjGGZQIhYeHo3Xr1hU+LpfLUaeOfgNxqXoenjUGaI8TOqkjEYoM7IzmPm1QqCrA6lOLqzwGZ40REZGt44WkrJSuFqH6vs6aa4OduplVrptSIpHgjX/HCm04uwL3C+5WegyOESIiIltn0kToww8/xGuvvWbKXVIFdCVCcpkUjf1dAAB3sguRll1Yrk6HWk+isVcL5Jfk4ttTX1R6DDXHCBERkY0zaSJ048YNpKSkmHKXVAGZVAJJ
+d4xNAtyg7PSDpGhnigqKT/IRyKRYHjLsQCA7898hazCjAqPwa4xIiKydUZfdFUsHsWLrpY5n5pdLtkpLFFBLpNCqitL+pdaUOP5n9rjUsZZvB8Zj1ebvqmznlQKNAl00/kYERHRo8zsF10ly9N1zTGlnazSJAgApBIp+jceCgDYeH5VhUsecNYYERHZOoNWlgaA9PR0rFixAklJSUhNTQUA+Pv7o127dhg8eDB8fHxMHiTpVjpOSGXUtr3qvYi5hybjcsZ5HE1LQoR/O5311GoBUmnliRUREZG1MqhF6NChQ2jQoAEWLlwINzc3dOrUCZ06dYKbmxsWLlyIRo0a4fDhw+aKlR6ia8D0g/KKSrDl75vIyi8u95iLwg09wp4HAGw8t6rCfXCcEBER2TKDxgg9/vjjaNGiBRISEiB5qPtFEASMGDECf//9N5KSkkweqKU8ymOE0nMKcSujQOdj+y6mY0HiP8gvVmFIuxA891itcnVO3zmGl395AgqZEjteOgt3e89yder7OcNeLjN57EREROZkljFCJ06cwJgxY8olQUDpbKQxY8bg+PHjBgdLxpFLK375QryckF9c2m229dQtnZfLaOLTCo29WqBIVYhf/lmrcz9sECIiIltmUCLk7++PgwcPVvj4wYMH4efnV+2gSD9yu4rH7gR5OOCx2u4AgLSsQhy5cl9nvRcbDQEA/FDBoGl2jRERkS0zaLD02LFjMXz4cBw5cgTdunXTJD1paWlITEzEsmXLMGfOHLMESuXZVdIiBAC9mgXg6NUMAMCWkzfRNrR811ePsOcx5+BHuJJ5EYdu7UHbwE5aj/PCq0REZMsMSoRGjhwJb29vzJs3D1988QVUqtKuF5lMhoiICKxatQr9+vUzS6BUnq7p8w+KqOMJXxclbmcX4ujVDNy4n48gDwetOk4KF/Sq+yJ+OLcSs/YuxpjHwtG+nrfmcSZCRERkywxeR6h///7Yv38/8vLycOPGDdy4cQN5eXnYv38/k6AaJpFIdF58tYxMKkGvZgGa+xuOXEN2gfYMspzCEmTd7QwAuJC1E/Hb9+N29n8DsEu4mBAREdkwoxdUlMvlCAgIQEBAAORyuSljIgNU1SoU3dgPin+n2e88dxuvfH0Av5y4oXncUSFDfk5tKNT1AUkJsmQ7cPxahuZxtggREZEtMzoRmjlzJjIyMsr9TTWrqnFCrg5yPNVEewB7bU8nzd9SiQQvta0NL/QCAOTKduLEtUzN4yUqJkJERGS7jE6EZsyYgXv37pX7m2qW3K7ql/D1DmGY1Ksxnm0RiAZ+zmjo56L1eLu6Xvjm5bcBQYpi6VUcun5WM4OMLUJERGTLDL7ERpkHp1rzuq2WI9fj8hcyqQRtQ73QNtRL5+NSiQR+zl7wUbTAneJjSCv6C1fvPYk6Xk4oYSJEREQ2jBddtXJVXWbDEI/5RgMA8mUHceJ6afcYW4SIiMiWMRGycpXNGjNUn0a9AQAF0lM4crV0QDUTISIismVMhKycKVuEouo0g0IIAiQlOHl3DwRBgEotsOuTiIhsFhMhK2fKREgqkaBz7e4AgEahFzTXlGOrEBER2SqTfIvquggr1QyZVAJTnv6Xm/UFAOy7/jtU6tKVwzlgmoiIbJVJEiF2nViWKVuFWvhFwkXhhvsFd3HyzmEAbBEiIiLbZfQ36JkzZxASEqL5u06dOqaKqVJLlixBSEgI7O3tERkZiYMHD1ZYd9myZejYsSM8PDzg4eGB6OjoSutbq6pWlzZoX1I5OtR6EgDwx9VtANgiREREtsvoRCg4OBjSf1c1Dg4OhkwmM1lQFVm/fj3i4uIwZcoUHD16FC1atEBMTAxu376ts/7u3bvx8ssvY9euXUhKSkJwcDCeeuop3LhxQ2d9a2XKFiEAaO7dFQCw4fQmnL6ZyRYhIiKyWUZ9g8pkMp3Jx927d82aEM2dOxfDhg3DkCFDEB4ejoSEBDg6OmLFihU6669ZswZvvfUWWrZsiUaN
GuHrr7+GWq1GYmKi2WK0BFMnQu7SNoAgRbYqGbsunuKFV4mIyGYZ9Q1a0ZigwsJCKBSKagVUkaKiIhw5cgTR0dGaMqlUiujoaCQlJem1j7y8PBQXF8PT07PCOoWFhcjKytK6PepkeqwubYjH69SGUh0OAPjz2na2CBERkc0y6BIbCxcuBFA6S+zrr7+Gs7Oz5jGVSoU///wTjRo1Mm2E/0pPT4dKpYKfn/YFRP38/HDu3Dm99vHBBx8gMDBQK5l6WHx8PKZOnVqtWGuaqRMhL2clApTtkVJyCim5e5CVV4wANweTHoOIiOhRYFAiNG/ePAClLUIJCQla3WAKhQIhISFISEgwbYQmMnPmTKxbtw67d++Gvb19hfUmTJiAuLg4zf2srCwEBwfXRIhGk5lh+YKowKeQcvVLFEhPYt/l62gYEG7yYxAREVmaQYlQcnIyAOCJJ57ATz/9BA8PD7MEpYu3tzdkMhnS0tK0ytPS0uDv71/ptnPmzMHMmTOxY8cONG/evNK6SqUSSqWy2vHWJKkZlsXsHNYCP6QEokR6E7+c+w2vtWciREREtseor9Bdu3bVaBIElLY4RUREaA10Lhv4HBUVVeF2s2bNwvTp07Ft2za0bt26JkKtcabuGgOAJoFuUKjrAgAu3b9s8v0TERE9CkzeljBt2jTs2bPH1LsFAMTFxWHZsmVYvXo1zp49izfffBO5ubkYMmQIAGDQoEGYMGGCpv5nn32GSZMmYcWKFQgJCUFqaipSU1ORk5NjlvgsRWqGrjFXezvYS30AAHcLbpl8/0RERI8CkydCK1euRExMDHr37m3qXaN///6YM2cOJk+ejJYtW+L48ePYtm2bZgD11atXcevWf1/aS5cuRVFREV544QUEBARobnPmzDF5bJZkjhYhiUQCV0Xpec0pvg01p9ATEZENMmiMkD6Sk5ORn5+PXbt2mXrXAIBRo0Zh1KhROh/bvXu31v2UlBSzxPCoMcdgaQDwdghASjZQJKQjI68Yns7WNXaKiIioKma5+ryDgwN69uxpjl2TDlIztAgBQJvgegAAB4dMcCkhIiKyRUYlQh9//LHOrpLMzEy8/PLL1Q6KDGeOmWN9mjUDAOSU3IZSbp5ki4iIyJKM+vpcvnw5OnTogMuX/5tNtHv3bjRr1gyXLl0yWXCkP3OME/J29IcEEpSoi5Gao/t6bkRERNbMqETo77//Rq1atdCyZUssW7YM48aNw1NPPYWBAwfir7/+MnWMpAdzjBOSS+XwcvAFAFzPsq0L1RIREQFGDpb28PDAhg0b8OGHH+KNN96AnZ0dfv31V3Tr1s3U8ZGezDVOyMcxAOn5aThx6xKeCIs0yzGIiIgsxeiRJYsWLcKCBQvw8ssvIywsDO+88w5OnDhhytjIAOZoEcouKMaV26WXI/n24DGT75+IiMjSjEqEunfvjqlTp2L16tVYs2YNjh07hk6dOuHxxx/HrFmzTB0j6cEcY4SclXaQwwsAkJ7PRRWJiMj2GJUIqVQq/P3333jhhRcAlE6XX7p0KTZu3Ki5MCvVLHN0jT24qGJmUVoVtYmIiKyPUWOEfv/9d53lvXr1wsmTJ6sVEBnHTEOE4KkMwLU8oFB9B7mFJXBSmnwNTiIiIovRu0VIEPRbUc/b29voYMh45lpd2s8pEABQIrmL1KwCsxyDiIjIUvROhJo0aYJ169ahqKio0nr//PMP3nzzTcycObPawZH+zDVrLMilFgBAJbmL1EwmQkREZFv07udYtGgRPvjgA7z11lt48skn0bp1awQGBsLe3h7379/HmTNnsHfvXpw+fRqjRo3Cm2++ac646SHmahGq4x4EABAk+bh05zba12OLHxER2Q69E6Fu3brh8OHD2Lt3L9avX481a9bgypUryM/Ph7e3N1q1aoVBgwZhwIAB8PDwMGfMpIO5WoQC3DwgFZygluTi3J0UAOFmOQ4REZElGDzytUOHDujQoYM5YqFqMMf0eQDwclJAJnhB
LcnFpfvXzHIMIiIiSzFqCtC0adMqfXzy5MlGBUPGM1fXmLezEjLBG8W4iuuZ181yDCIiIksxKhH6+eefte4XFxcjOTkZdnZ2qFu3LhMhCzDH1ecBwN1Rjsfr1MPu60fRqbHMPAchIiKyEKMSoWPHyl9uISsrC4MHD0bfvn2rHRQZzlwtQg4KGRr6hGD3dSA1lxdeJSIi22KydgRXV1dMnToVkyZNMtUuyQDmGiPkrLSDn2PpWkK8Aj0REdkak3aoZGZmIjMz05S7JD1JJBKYo1HI2d5Os6jirRwmQkREZFuM6hpbuHCh1n1BEHDr1i18++236NGjh0kCI8NJJRKo9FwBXB8SCeCssIOqpHQ5hOR713Dxdg7q+Tqb7BhERESWZFQi9PCFVaVSKXx8fBAbG4sJEyaYJDAynEwqgUptukRIJpVAKpUg7b4TACBPdR/HrqUxESIiIpthVCKUnJxs6jjIBGQmnjkml5X2tdX19IdEUECQFOHcnasA6pr2QERERBZipknXZAlSEw8Ssvt3Tn6wlwNkghcA4PI9LqpIRES2g4mQDTH1zLGy/QV7OGoSoauZTISIiMh2MBGyIaZuEZL/29cW7OkIu38TodScmyY9BhERkSUxEbIhpm4Rsvt3jJCjwg5Odn4AgLv5t0x6DCIiIktiImRDTN4i9MB1OzzsSxOh7JLbKFapTXocIiIiS2EiZENMfb0xmey/xMrfKQgAUCJJx53sQtMeiIiIyEKYCNkQU19vzO6BrrbabqWJkAr3cCuzwKTHISIishQmQjbE1GOE5A8sTNQ0IAwAoJbeg1TCrjEiIrINTIRsiNSEiZBEop1YjXmiDaQSKQSo4OfBrjEiIrINTIRsiCm7xuxk2vuSy+TwcfQHANzgVeiJiMhGMBGyIabsGrPTMfI6wLn0KvTXs66b7DhERESWxETIhphy+rxcVn5fQa61AAA3stkiREREtsHqEqElS5YgJCQE9vb2iIyMxMGDByut/8MPP6BRo0awt7dHs2bNsHXr1hqKtOaZskVI176u3bEHAMzflWSy4xAREVmSVSVC69evR1xcHKZMmYKjR4+iRYsWiImJwe3bt3XW/+uvv/Dyyy9j6NChOHbsGPr06YM+ffrg1KlTNRx5zTDlpDG5jkvZFxS4AQDScm9BrRZMdzAiIiILkQiCYDXfaJGRkWjTpg0WL14MAFCr1QgODsbbb7+N8ePHl6vfv39/5ObmYvPmzZqyxx9/HC1btkRCQoJex8zKyoKbmxsyMzPh6upqmidiRqduZMIUr2iguz28nJVaZZ2XTMef6ZOhVDXDgm4/w89VCUeFHST/dsm1ruMBDyeFpv7trAKcuJ5Z5bFkUqBrIz+tsjM3s3AjI7/KbX1clGgZ7K5VtvefdOQXq6rctpG/C4I9HTX3cwtL8Nelu1VuBwDt63nBUWGnuX/1bh7Op2VXuZ2jQob29by1yo5dvY/0nKIqtw32dEAjf+334I4zadDn5W4R7AZfF3vN/bs5hTh6NUOPLYFujXy1ZiReSMvGlbt5VW7n6SRHRB1PrbKkS3eRU1hS5bb1fZ0R4u2kuV9QrMKef9L1ijcyzBOu9nLN/ZsZ+Th9M6vK7RR2UnRu4KNVdvJ6JlKzql43K8DNHk2D3LTKdp2/jRJV1a9O0yBXBLg5aO5n5hfjYPK9KrcDgM4NfKCw++9Hy+U7Obh0J7fK7Vzs7fB4mJdW2aGUe8jIK65y21BvR9TzddHcL1Gpsev8Hb3i5WcEPyMepOszwpT0/f62q/CRR0xRURGOHDmCCRMmaMqkUimio6ORlKS7qyYpKQlxcXFaZTExMdi0aVOFxyksLERh4X/Tw7Oyqv4AfZTIpBK9PnyrYqejRSjYtRaQDhRKT+HNXU3KPa6wk2qNU1ILAopK9FtzyF4u07pfrFJDpUerk1QqgeKhWAtL1NAnv5fLpFpdgIIAFJZU/eEIAEq5DA82wKnUgl6XHpFIJFDaacdbpFLr1cIm
k0rKtdQV6PFhDpj2tSlRqVGiz2sjkWh9SQP6vzZ2MqnWgp4CgEI9n6vSTqpJzgFDXhtAaWfc+7A6r03596GAQmNfG7WAEmPfhyVqqI14bQDLvA/5GVGetX5GSCTALy/9gm5h3fQ6nqlZTSKUnp4OlUoFPz/tXwV+fn44d+6czm1SU1N11k9NTa3wOPHx8Zg6dWr1A7YQkyVCOvrZnqzXBt9fcoVakgUB5dcSKtT1/6Znd12+rkYCPbZVCcZvW6QG8PD/uZ7xFhh5TAHGx1siACUPb6tnvBZ5baqxbbG69GbodgBQYORztdRrU533obHxVue5Vue14WdE1duJ9TNCJeiXsJmD1SRCNWXChAlarUhZWVkIDg62YESGMdXMsYfXEQKAVyObwE56CAeuJEMqkUCpkMHeTgLZv1PtY5r4w8/1v6bVK3dz8ceFqpvMZRIJBjxeR6ts/6V0XLidU+W2tT0c0aWRr1bZT0ev69X9EhnqiYYPNCNn5Rdh0/GbVW4HAH1aBWl1v5y9lYVDKVV3Z7jY26Fvq1paZYln0/Rq4m/o54LIh7ozvt2foldXaNeGvqj1QBP/zYx87DibVvWGAF6NrKPV7H0k5R5O36q6pdTfzR5Phftrlf3fiRu4r0f3y2O1PbS6mvIKS7DxqH7LNvRqFqDVrXvpTjb2Xay6O8NBLsOLrbX/1/+8cAcpd6vuagrzdkKH+trdausOXkWRHi0AHep5I8zHWXM/PbsQW0/dqnI7AOjfOhjKB36N/309A8evZVS5naeTAk83D9Qq23bqFm7rcR3BZkFuaFXbQ3O/WKXG9wev6hUvPyP4GfGgBz8j/Jz8qqhtPlYzRqioqAiOjo7YuHEj+vTpoymPjY1FRkYG/ve//5Xbpnbt2oiLi8Po0aM1ZVOmTMGmTZtw4sQJvY5rbWOEktNzkaPzp4hhmgS66lypulilRn6xCi5KO63uByIiokeJvt/fVjNrTKFQICIiAomJiZoytVqNxMREREVF6dwmKipKqz4A/P777xXWtwWmWF1aIqn4ch1ymRSu9nImQUREZBOsqmssLi4OsbGxaN26Ndq2bYv58+cjNzcXQ4YMAQAMGjQIQUFBiI+PBwC8++676Ny5Mz7//HP06tUL69atw+HDh/HVV19Z8mmYlY4FoQ2ma+o8ERGRLbKqRKh///64c+cOJk+ejNTUVLRs2RLbtm3TDIi+evUqpA9kAu3atcPatWvx0Ucf4cMPP0T9+vWxadMmNG3a1FJPwexMsaiirvFBREREtshqxghZirWNEbqdVYC0rOpdHd7NQY7aXo5VVyQiInpE2dwYIdKPKcbuyNgiREREIsFEyMaYomtMbsprdRARET3CmAjZGFPMGtO1qjQREZEt4jeejTHFrDEOliYiIrFgImRjTDJrjF1jREQkEkyEbIwpLrFhZ4pmJSIiIivAbzwbY5LB0uwaIyIikWAiZGOqO1haKjXNFHwiIiJrwETIxkilElQnj+HlNYiISEz4rWeDqpMIcaA0ERGJCRMhG1SdcUJsESIiIjHht54Nqs44IVMMtiYiIrIWTIRskLQayQwXUyQiIjFhImSDqtMiJOcaQkREJCL81rNB1eneYosQERGJCRMhG1StrjG2CBERkYjwW88GVadrjC1CREQkJkyEbFB1GnW4jhAREYkJEyEbZOyFV2VSCS+vQUREosJEyAYZ2zXGi60SEZHYMBGyQcYOluZiikREJDZMhGyQsQkNEyEiIhIbJkI2yNiuMWPHFhEREVkrJkI2yNhZY2wRIiIisWEiZIOMbRFiIkRERGLDRMgGGZvQsGeMiIjEhomQDZJIJEYlNdVZkZqIiMgaMRGyUcYMfGbXGBERiQ0TIRtlTFJTnYu1EhERWSMmQjZKZsQry64xIiISGyZCNkpmxBx6do0REZHYMBGyUca07nBBRSIiEhsmQjbKmEUV2SJERERiw0TIRhk1WJp5EBERiYzVJEL37t3DgAED4OrqCnd3
dwwdOhQ5OTmV1n/77bfRsGFDODg4oHbt2njnnXeQmZlZg1FbjqFdYxJJ6fpDREREYmI1idCAAQNw+vRp/P7779i8eTP+/PNPDB8+vML6N2/exM2bNzFnzhycOnUKq1atwrZt2zB06NAajNpyDJ0Kz24xIiISI4kgCIKlg6jK2bNnER4ejkOHDqF169YAgG3btqFnz564fv06AgMD9drPDz/8gFdffRW5ubmws7PTa5usrCy4ubkhMzMTrq6uRj+HmnY/twjX7+frXV8pl6KBn4sZIyIiIqo5+n5/W0WLUFJSEtzd3TVJEABER0dDKpXiwIEDeu+n7GRUlgQVFhYiKytL62aNDG0R4owxIiISI6tIhFJTU+Hr66tVZmdnB09PT6Smpuq1j/T0dEyfPr3S7jQAiI+Ph5ubm+YWHBxsdNyWZGhXF7vGiIhIjCyaCI0fP/7fC4RWfDt37ly1j5OVlYVevXohPDwcH3/8caV1J0yYgMzMTM3t2rVr1T6+JRg6WJp5EBERiZF+A2XM5L333sPgwYMrrRMWFgZ/f3/cvn1bq7ykpAT37t2Dv79/pdtnZ2eje/fucHFxwc8//wy5XF5pfaVSCaVSqVf8jzJDW3jYNUZERGJk0UTIx8cHPj4+VdaLiopCRkYGjhw5goiICADAzp07oVarERkZWeF2WVlZiImJgVKpxC+//AJ7e3uTxf6oY9cYERFR1axijFDjxo3RvXt3DBs2DAcPHsS+ffswatQovPTSS5oZYzdu3ECjRo1w8OBBAKVJ0FNPPYXc3FwsX74cWVlZSE1NRWpqKlQqlSWfTo0wNK9hIkRERGJk0RYhQ6xZswajRo1Ct27dIJVK8fzzz2PhwoWax4uLi3H+/Hnk5eUBAI4ePaqZUVavXj2tfSUnJyMkJKTGYrcEiUQCqRRQq/Wrz64xIiISI6tYR8iSrHUdIQA4l5qF4hL9Xt4gDwd4OinMHBEREVHNsKl1hMg4hswcM+Zq9URERNaOiZANM2RRRWOuVk9ERGTt+PVnwwxp5eEYISIiEiMmQjbMkJlgnDVGRERixETIhhnUNcYWISIiEiEmQjbMoMHSbBEiIiIRYiJkwwwZAM08iIiIxIiJkA2z0zMTkkpLF2AkIiISGyZCNkzfrjGODyIiIrFiImTD9O0a4/ggIiISKyZCNkzfBIctQkREJFZMhGyYvgkOW4SIiEismAjZMH0THF5njIiIxIqJkA3Te7A03wVERCRS/Aq0YVKpBPrkQuwaIyIisWIiZOP0GSfEwdJERCRWTIRsnD6tPUyEiIhIrJgI2TiZHq8wu8aIiEismAjZOH1aezhrjIiIxIqJkI3Tq2uM7wIiIhIpfgXaOH0SIXaNERGRWDERsnEcLE1ERFQxJkI2Tp/xP2wRIiIisWIiZOOkbBEiIiKqEBMhG8cWISIioooxEbJxVbUIccYYERGJGb8GbVxVrT1sDSIiIjFjImTjquoa42KKREQkZkyEbFxVXV/6DKYmIiKyVUyEbFxVLT6cMUZERGLGRMjGVTlGiIkQERGJGBMhGyeRSFBZrsNZY0REJGb8GhQBO1nFmRBnjRERkZgxERKByrq/2DVGRERiZjWJ0L179zBgwAC4urrC3d0dQ4cORU5Ojl7bCoKAHj16QCKRYNOmTeYN9BFU2cwwzhojIiIxs5pEaMCAATh9+jR+//13bN68GX/++SeGDx+u17bz58+HRMQtH2wRIiIi0s3O0gHo4+zZs9i2bRsOHTqE1q1bAwAWLVqEnj17Ys6cOQgMDKxw2+PHj+Pzzz/H4cOHERAQUFMhP1IqGwfEFiEiIhIzq2gRSkpKgru7uyYJAoDo6GhIpVIcOHCgwu3y8vLwyiuvYMmSJfD399frWIWFhcjKytK6WbtKu8aYBxERkYhZRSKUmpoKX19frTI7Ozt4enoiNTW1wu3GjBmDdu3a4dlnn9X7WPHx8XBzc9PcgoODjY77UVFp1xgzISIiEjGLJkLjx4//
d52bim/nzp0zat+//PILdu7cifnz5xu03YQJE5CZmam5Xbt2zajjP0oqWyuIK0sTEZGYWXSM0HvvvYfBgwdXWicsLAz+/v64ffu2VnlJSQnu3btXYZfXzp07cenSJbi7u2uVP//88+jYsSN2796tczulUgmlUqnvU7AKbBEiIiLSzaKJkI+PD3x8fKqsFxUVhYyMDBw5cgQREREAShMdtVqNyMhInduMHz8er7/+ulZZs2bNMG/ePPTu3bv6wVuRypIdzhojIiIxs4pZY40bN0b37t0xbNgwJCQkoLi4GKNGjcJLL72kmTF248YNdOvWDd988w3atm0Lf39/na1FtWvXRmhoaE0/BYuqaLC0RMJZY0REJG5WMVgaANasWYNGjRqhW7du6NmzJzp06ICvvvpK83hxcTHOnz+PvLw8C0b5aLKrJBEiIiISM6toEQIAT09PrF27tsLHQ0JCIAhCpfuo6nFbVdGAaI4PIiIisbOaFiEyXkUJD8cHERGR2DEREoGKEh6ODyIiIrFjIiQCUqlE53ggtggREZHYMRESCV3jhDhGiIiIxI6JkEjoSnrYNUZERGLHREgkZDpeaXaNERGR2DEREgldXWNsECIiIrFjIiQS7BojIiIqj4mQSOgcLM2uMSIiEjkmQiLBFiEiIqLymAiJhK7rjXH6PBERiZ3VXGuMqufB1h+pFHC1l0NpxzyYiIjEjYmQSMilUrg7yuHmKIeL0g4Sjg8iIiJiIiQWbv8mQURERPQf9o0QERGRaDERIiIiItFiIkRERESixUSIiIiIRIuJEBEREYkWEyEiIiISLSZCREREJFpMhIiIiEi0mAgRERGRaDERIiIiItFiIkRERESixUSIiIiIRIuJEBEREYkWEyEiIiISLSZCREREJFp2lg7gUScIAgAgKyvLwpEQERGRvsq+t8u+xyvCRKgK2dnZAIDg4GALR0JERESGys7OhpubW4WPS4SqUiWRU6vVuHnzJlxcXCCRSEy236ysLAQHB+PatWtwdXU12X6pPJ7rmsHzXDN4nmsGz3PNMOd5FgQB2dnZCAwMhFRa8UggtghVQSqVolatWmbbv6urK//JagjPdc3gea4ZPM81g+e5ZpjrPFfWElSGg6WJiIhItJgIERERkWgxEbIQpVKJKVOmQKlUWjoUm8dzXTN4nmsGz3PN4HmuGY/CeeZgaSIiIhIttggRERGRaDERIiIiItFiIkRERESixUSIiIiIRIuJkBktWbIEISEhsLe3R2RkJA4ePFhp/R9++AGNGjWCvb09mjVrhq1bt9ZQpNbNkPO8bNkydOzYER4eHvDw8EB0dHSVrwv9x9D3dJl169ZBIpGgT58+5g3QRhh6njMyMjBy5EgEBARAqVSiQYMG/PzQg6Hnef78+WjYsCEcHBwQHByMMWPGoKCgoIaitU5//vknevfujcDAQEgkEmzatKnKbXbv3o3HHnsMSqUS9erVw6pVq8wbpEBmsW7dOkGhUAgrVqwQTp8+LQwbNkxwd3cX0tLSdNbft2+fIJPJhFmzZglnzpwRPvroI0EulwsnT56s4citi6Hn+ZVXXhGWLFkiHDt2TDh79qwwePBgwc3NTbh+/XoNR259DD3XZZKTk4WgoCChY8eOwrPPPlszwVoxQ89zYWGh0Lp1a6Fnz57C3r17heTkZGH37t3C8ePHazhy62LoeV6zZo2gVCqFNWvWCMnJycL27duFgIAAYcyYMTUcuXXZunWrMHHiROGnn34SAAg///xzpfUvX74sODo6CnFxccKZM2eERYsWCTKZTNi2bZvZYmQiZCZt27YVRo4cqbmvUqmEwMBAIT4+Xmf9fv36Cb169dIqi4yMFN544w2zxmntDD3PDyspKRFcXFyE1atXmytEm2HMuS4pKRHatWsnfP3110JsbCwTIT0Yep6XLl0qhIWFCUVFRTUVok0w9DyPHDlS6Nq1q1ZZXFyc0L59e7PGaUv0SYTef/99oUmTJlpl/fv3
F2JiYswWF7vGzKCoqAhHjhxBdHS0pkwqlSI6OhpJSUk6t0lKStKqDwAxMTEV1ifjzvPD8vLyUFxcDE9PT3OFaROMPdfTpk2Dr68vhg4dWhNhWj1jzvMvv/yCqKgojBw5En5+fmjatClmzJgBlUpVU2FbHWPOc7t27XDkyBFN99nly5exdetW9OzZs0ZiFgtLfBfyoqtmkJ6eDpVKBT8/P61yPz8/nDt3Tuc2qampOuunpqaaLU5rZ8x5ftgHH3yAwMDAcv94pM2Yc713714sX74cx48fr4EIbYMx5/ny5cvYuXMnBgwYgK1bt+LixYt46623UFxcjClTptRE2FbHmPP8yiuvID09HR06dIAgCCgpKcGIESPw4Ycf1kTIolHRd2FWVhby8/Ph4OBg8mOyRYhEa+bMmVi3bh1+/vln2NvbWzocm5KdnY2BAwdi2bJl8Pb2tnQ4Nk2tVsPX1xdfffUVIiIi0L9/f0ycOBEJCQmWDs2m7N69GzNmzMAXX3yBo0eP4qeffsKWLVswffp0S4dG1cQWITPw9vaGTCZDWlqaVnlaWhr8/f11buPv729QfTLuPJeZM2cOZs6ciR07dqB58+bmDNMmGHquL126hJSUFPTu3VtTplarAQB2dnY4f/486tata96grZAx7+mAgADI5XLIZDJNWePGjZGamoqioiIoFAqzxmyNjDnPkyZNwsCBA/H6668DAJo1a4bc3FwMHz4cEydOhFTKdgVTqOi70NXV1SytQQBbhMxCoVAgIiICiYmJmjK1Wo3ExERERUXp3CYqKkqrPgD8/vvvFdYn484zAMyaNQvTp0/Htm3b0Lp165oI1eoZeq4bNWqEkydP4vjx45rbM888gyeeeALHjx9HcHBwTYZvNYx5T7dv3x4XL17UJJoAcOHCBQQEBDAJqoAx5zkvL69cslOWfAq8ZKfJWOS70GzDsEVu3bp1glKpFFatWiWcOXNGGD58uODu7i6kpqYKgiAIAwcOFMaPH6+pv2/fPsHOzk6YM2eOcPbsWWHKlCmcPq8HQ8/zzJkzBYVCIWzcuFG4deuW5padnW2pp2A1DD3XD+OsMf0Yep6vXr0quLi4CKNGjRLOnz8vbN68WfD19RU++eQTSz0Fq2DoeZ4yZYrg4uIifP/998Lly5eF3377Tahbt67Qr18/Sz0Fq5CdnS0cO3ZMOHbsmABAmDt3rnDs2DHhypUrgiAIwvjx44WBAwdq6pdNnx83bpxw9uxZYcmSJZw+b80WLVok1K5dW1AoFELbtm2F/fv3ax7r3LmzEBsbq1V/w4YNQoMGDQSFQiE0adJE2LJlSw1HbJ0MOc916tQRAJS7TZkypeYDt0KGvqcfxERIf4ae57/++kuIjIwUlEqlEBYWJnz66adCSUlJDUdtfQw5z8XFxcLHH38s1K1bV7C3txeCg4OFt956S7h//37NB25Fdu3apfMzt+zcxsbGCp07dy63TcuWLQWFQiGEhYUJK1euNGuMEkFgmx4RERGJE8cIERERkWgxESIiIiLRYiJEREREosVEiIiIiESLiRARERGJFhMhIiIiEi0mQkRERCRaTISIiIhItJgIERERkWgxESIiIiLRYiJERKJy584d+Pv7Y8aMGZqyv/76CwqFotxVr4nI9vFaY0QkOlu3bkWfPn3w119/oWHDhmjZsiWeffZZzJ0719KhEVENYyJERKI0cuRI7NixA61bt8bJkydx6NAhKJVKS4dFRDWMiRARiVJ+fj6aNm2Ka9eu4ciRI2jWrJmlQyIiC+AYISISpUuXLuHmzZtQq9VISUmxdDhEZCFsESIi0SkqKkLbtm3RsmVLNGzYEPPnz8fJkyfh6+tr6dCIqIYxESIi0Rk3bhw2btyIEydOwNnZGZ07d4abmxs2b95s6dCIqIaxa4yIRGX37t2YP38+vv32W7i6ukIqleLbb7/Fnj17sHTpUkuHR0Q1jC1CREREJFpsESIiIiLRYiJEREREosVEiIiI
iESLiRARERGJFhMhIiIiEi0mQkRERCRaTISIiIhItJgIERERkWgxESIiIiLRYiJEREREosVEiIiIiESLiRARERGJ1v8DWVNZgBVaNfkAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# End-to-end prediction (mean and ±3σ band) vs. OOD ground truth for one parameter m at one time.\n",
    "t_idx = 2\n",
    "parameter_idx = 0\n",
    "with torch.no_grad():\n",
    "    m_val = x_ood_test[parameter_idx, 0, 0, 0].item()\n",
    "    # Move to CPU once so matplotlib never receives CUDA tensors.\n",
    "    mu_plot = mu[parameter_idx, :, t_idx, 0].cpu()\n",
    "    std_plot = std[parameter_idx, :, t_idx, 0].cpu()\n",
    "    true_plot = y_ood_test[parameter_idx, :, t_idx, 0].cpu()\n",
    "\n",
    "    plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "    plt.title(\"Learning PME for parameter m = {k:.2f}\".format(k=m_val))\n",
    "    plt.xlabel(\"x\")\n",
    "    # Raw string: \\mu, \\pm and \\sigma are mathtext, not Python escape sequences.\n",
    "    plt.plot(grid, mu_plot, '--', lw=2, label=r\"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "    plt.fill_between(grid, mu_plot + 3 * std_plot, mu_plot - 3 * std_plot, alpha=0.2)\n",
    "    plt.plot(grid, true_plot, color=\"green\", label=\"true\")\n",
    "    plt.legend(loc=\"upper right\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "70f6d09c-1fb6-4295-b33b-4000e11722d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.2588)"
      ]
     },
     "execution_count": 147,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Parameter value m of the OOD sample plotted above (parameter_idx is set in that plotting cell).\n",
    "x_ood_test[parameter_idx, 0, 0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "4f6352ec-3b12-445d-837e-e619b6815350",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.8865)"
      ]
     },
     "execution_count": 148,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Parameter value m of the last OOD sample.\n",
    "x_ood_test[-1][0, 0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "a2b9fc32-f85d-4ea2-80b9-2019a23d48da",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'new_mu' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[149], line 13\u001b[0m\n\u001b[1;32m     11\u001b[0m param_val \u001b[38;5;241m=\u001b[39m x_ood_test[parameters_idx, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m     12\u001b[0m true_vals \u001b[38;5;241m=\u001b[39m y_ood_test[parameters_idx, :, t_idx, \u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 13\u001b[0m pred_mu_conserv \u001b[38;5;241m=\u001b[39m \u001b[43mnew_mu\u001b[49m[parameters_idx, :, t_idx, \u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     14\u001b[0m pred_std_conserv \u001b[38;5;241m=\u001b[39m new_std[parameters_idx, :, t_idx, \u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     15\u001b[0m pred_mu_inf \u001b[38;5;241m=\u001b[39m u_proj_reshaped[parameters_idx, :, t_idx, \u001b[38;5;241m0\u001b[39m]\n",
      "\u001b[0;31mNameError\u001b[0m: name 'new_mu' is not defined"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 1350x840 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Paper figure: ProbConserv, ProbHardInference and ProbHardE2E vs. the true PME solution\n",
    "# for the first and last OOD samples at one prediction time.\n",
    "# new_mu/new_std and u_proj_reshaped/stds are produced by other cells; fail early with a\n",
    "# clear message instead of a bare NameError halfway through the loop.\n",
    "_missing = [name for name in (\"new_mu\", \"new_std\", \"u_proj_reshaped\", \"stds\") if name not in globals()]\n",
    "if _missing:\n",
    "    raise NameError(f\"Run the cells that define {_missing} before this plotting cell.\")\n",
    "\n",
    "t_idx = 10\n",
    "mse = torch.nn.MSELoss()\n",
    "\n",
    "for parameters_idx in [0, -1]:\n",
    "    with torch.no_grad():\n",
    "        # Compact, high-DPI figure sized for the paper.\n",
    "        fig = plt.figure(figsize=(4.5, 2.8), dpi=300)\n",
    "\n",
    "        time_label = t[slice(*tpred)][t_idx].item()\n",
    "        param_val = x_ood_test[parameters_idx, 0, 0, 0].item()\n",
    "        true_vals = y_ood_test[parameters_idx, :, t_idx, 0]\n",
    "        pred_mu_conserv = new_mu[parameters_idx, :, t_idx, 0]\n",
    "        pred_std_conserv = new_std[parameters_idx, :, t_idx, 0]\n",
    "        pred_mu_inf = u_proj_reshaped[parameters_idx, :, t_idx, 0]\n",
    "        pred_std_inf = stds[parameters_idx, :, t_idx, 0]\n",
    "        pred_mu_e2e = mu[parameters_idx, :, t_idx, 0]\n",
    "        pred_std_e2e = std[parameters_idx, :, t_idx, 0]\n",
    "\n",
    "        # True solution\n",
    "        plt.plot(grid, true_vals, color=\"black\", lw=3.5, label=\"True\", zorder=3)\n",
    "\n",
    "        # ProbConserv\n",
    "        plt.plot(grid, pred_mu_conserv, '--', lw=2.5, color=\"#ef233c\",\n",
    "                 label=r\"ProbConserv\", zorder=2)\n",
    "        plt.fill_between(grid,\n",
    "                         pred_mu_conserv + 3 * pred_std_conserv,\n",
    "                         pred_mu_conserv - 3 * pred_std_conserv,\n",
    "                         color=\"#ef233c\", alpha=0.3, label=\"_nolegend_\", zorder=1)\n",
    "\n",
    "        # ProbHardInference\n",
    "        plt.plot(grid, pred_mu_inf, '-.', lw=2.5, color=\"#3a86ff\",\n",
    "                 label=r\"ProbHardInference\", zorder=4)\n",
    "        plt.fill_between(grid,\n",
    "                         pred_mu_inf + 3 * pred_std_inf,\n",
    "                         pred_mu_inf - 3 * pred_std_inf,\n",
    "                         color=\"#3a86ff\", alpha=0.3, label=\"_nolegend_\", zorder=3)\n",
    "\n",
    "        # ProbHardE2E\n",
    "        plt.plot(grid, pred_mu_e2e, ':', lw=2.5, color=\"#06d6a0\",\n",
    "                 label=r\"ProbHardE2E\", zorder=4)\n",
    "        plt.fill_between(grid,\n",
    "                         pred_mu_e2e + 3 * pred_std_e2e,\n",
    "                         pred_mu_e2e - 3 * pred_std_e2e,\n",
    "                         color=\"#06d6a0\", alpha=0.2, label=\"_nolegend_\", zorder=4)\n",
    "\n",
    "        # Labels\n",
    "        plt.xlabel(\"x\", fontsize=10)\n",
    "        plt.ylabel(r\"$u(x, t={:.2f})$\".format(time_label), fontsize=10)\n",
    "\n",
    "        # Ticks, grid and fixed zoom window\n",
    "        plt.xticks(fontsize=10)\n",
    "        plt.yticks(fontsize=10)\n",
    "        plt.grid(True, linestyle=\"--\", alpha=0.4)\n",
    "        plt.xlim((0.45, 0.58))\n",
    "        plt.ylim((-0.15, 0.7))\n",
    "\n",
    "        plt.legend(fontsize=10, loc=\"upper right\", frameon=False)\n",
    "\n",
    "        # Tight layout for Overleaf\n",
    "        plt.tight_layout(pad=0.3)\n",
    "        plt.savefig(f\"nonlinear_PME_CRPS_m{param_val}.pdf\")\n",
    "        plt.show()\n",
    "        plt.close(fig)  # do not accumulate open figures across loop iterations\n",
    "\n",
    "        print(\"MSE error ProbHardE2E:\", mse(true_vals, pred_mu_e2e).item())\n",
    "        print(\"MSE error ProbConserv:\", mse(true_vals, pred_mu_conserv).item())\n",
    "        print(\"MSE error ProbHardInference:\", mse(true_vals, pred_mu_inf).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7655f00d-968f-4904-9eeb-ff0b55c15af3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paper figure with markers and grouped legend entries (each method's line and ±3σ band\n",
    "# share one handle), for the first and last OOD samples at one prediction time.\n",
    "# new_mu/new_std and u_proj_reshaped/stds are produced by other cells; fail early if missing.\n",
    "_missing = [name for name in (\"new_mu\", \"new_std\", \"u_proj_reshaped\", \"stds\") if name not in globals()]\n",
    "if _missing:\n",
    "    raise NameError(f\"Run the cells that define {_missing} before this plotting cell.\")\n",
    "\n",
    "t_idx = 10\n",
    "mse = torch.nn.MSELoss()\n",
    "\n",
    "for parameters_idx in [0, -1]:\n",
    "    with torch.no_grad():\n",
    "        # Compact figure\n",
    "        fig, ax = plt.subplots(figsize=(4.5, 2.8), dpi=300)\n",
    "\n",
    "        time_label = t[slice(*tpred)][t_idx].item()\n",
    "        param_val = x_ood_test[parameters_idx, 0, 0, 0].item()\n",
    "\n",
    "        # Move every curve to CPU once: ax.plot AND ax.fill_between both need host data\n",
    "        # (matplotlib cannot convert CUDA tensors).\n",
    "        true_vals        = y_ood_test[parameters_idx, :, t_idx, 0].cpu()\n",
    "        pred_mu_conserv  = new_mu[parameters_idx, :, t_idx, 0].cpu()\n",
    "        pred_std_conserv = new_std[parameters_idx, :, t_idx, 0].cpu()\n",
    "        pred_mu_inf      = u_proj_reshaped[parameters_idx, :, t_idx, 0].cpu()\n",
    "        pred_std_inf     = stds[parameters_idx, :, t_idx, 0].cpu()\n",
    "        pred_mu_e2e      = mu[parameters_idx, :, t_idx, 0].cpu()\n",
    "        pred_std_e2e     = std[parameters_idx, :, t_idx, 0].cpu()\n",
    "\n",
    "        # 1) True solution\n",
    "        l_true, = ax.plot(grid, true_vals, color=\"black\", lw=3.5)\n",
    "\n",
    "        # 2) ProbConserv (red, '*' marker)\n",
    "        l_pc, = ax.plot(grid, pred_mu_conserv, '-*', color=\"#ef233c\", lw=2.5)\n",
    "        s_pc = ax.fill_between(grid,\n",
    "                               pred_mu_conserv + 3*pred_std_conserv,\n",
    "                               pred_mu_conserv - 3*pred_std_conserv,\n",
    "                               color=\"#ef233c\", alpha=0.2)\n",
    "\n",
    "        # 3) ProbHardInference (blue, '^' marker)\n",
    "        l_ph_inf, = ax.plot(grid, pred_mu_inf, '-^', color=\"#3a86ff\", lw=2.5)\n",
    "        s_ph_inf = ax.fill_between(grid,\n",
    "                                   pred_mu_inf + 3*pred_std_inf,\n",
    "                                   pred_mu_inf - 3*pred_std_inf,\n",
    "                                   color=\"#3a86ff\", alpha=0.2)\n",
    "\n",
    "        # 4) ProbHardE2E (purple, 'o' marker)\n",
    "        l_ph_e2e, = ax.plot(grid, pred_mu_e2e, '-o', color=\"#BC7FF7\", lw=2.2)\n",
    "        s_ph_e2e = ax.fill_between(grid,\n",
    "                                   pred_mu_e2e + 3*pred_std_e2e,\n",
    "                                   pred_mu_e2e - 3*pred_std_e2e,\n",
    "                                   color=\"#BC7FF7\", alpha=0.2)\n",
    "\n",
    "        # Legend: a (band, line) tuple renders as a single combined entry\n",
    "        handles = [\n",
    "            l_true,\n",
    "            (s_pc,    l_pc),\n",
    "            (s_ph_inf, l_ph_inf),\n",
    "            (s_ph_e2e, l_ph_e2e),\n",
    "        ]\n",
    "        labels = [\n",
    "            \"True\",\n",
    "            \"ProbConserv ±3σ\",\n",
    "            \"ProbHardInf ±3σ\",\n",
    "            \"ProbHardE2E ±3σ\",\n",
    "        ]\n",
    "        ax.legend(handles, labels,\n",
    "                  fontsize=10, loc=\"upper right\", frameon=False)\n",
    "\n",
    "        # Labels, limits, grid\n",
    "        ax.set_xlabel(\"x\", fontsize=10)\n",
    "        ax.set_ylabel(r\"$u(x, t={:.2f})$\".format(time_label), fontsize=10)\n",
    "        ax.set_xlim(0.45, 0.58)\n",
    "        ax.set_ylim(-0.15, 0.7)\n",
    "        ax.grid(True, linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "        fig.tight_layout(pad=0.3)\n",
    "        fig.savefig(f\"nonlinear_PME_CRPS_m{param_val}.pdf\")\n",
    "        plt.show()\n",
    "        plt.close(fig)  # do not accumulate open figures across loop iterations\n",
    "\n",
    "        # Report MSEs\n",
    "        print(\"MSE ProbHardE2E:       \", mse(true_vals, pred_mu_e2e).item())\n",
    "        print(\"MSE ProbConserv:       \", mse(true_vals, pred_mu_conserv).item())\n",
    "        print(\"MSE ProbHardInference: \", mse(true_vals, pred_mu_inf).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "3a4babe7-d69b-4f4e-92a3-4c19c8527702",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _predict_with_uncertainty(model, x):\n",
    "    \"\"\"Run `model` on `x` without gradients and return CPU tensors (mu, var, std).\n",
    "\n",
    "    A tuple output is read as (mean, variance); any other output is treated as a\n",
    "    deterministic mean with zero variance.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        out = model(x)\n",
    "\n",
    "    if isinstance(out, tuple):\n",
    "        mu, var = out[0].cpu(), out[1].cpu()\n",
    "        std = torch.sqrt(var)\n",
    "    else:\n",
    "        mu = out.cpu()\n",
    "        std = torch.zeros_like(mu)\n",
    "        var = torch.square(std)\n",
    "    return mu, var, std\n",
    "\n",
    "\n",
    "def _project_onto_constraint(mu, var, x, residual_fn, device, projection_fn, max_iter):\n",
    "    \"\"\"Project a 4-D prediction (mu, var) onto the constraint `residual_fn`; return CPU (mu, var, std).\n",
    "\n",
    "    Each sample is flattened to (nf, -1) for `projection_fn`, and the result is reshaped\n",
    "    back to (nf, nx, nt, 1).\n",
    "    \"\"\"\n",
    "    nf, nx, nt, _ = mu.shape\n",
    "    flat_mu = mu.view(nf, -1).to(device)\n",
    "    flat_var = var.view(nf, -1).to(device)\n",
    "    flat_x = x.view(nf, -1).to(device)\n",
    "\n",
    "    # relu clamps negative predicted values to zero before projecting.\n",
    "    u_proj, u_var = projection_fn(torch.relu(flat_mu), flat_var, flat_x, residual_fn, max_iter=max_iter)\n",
    "\n",
    "    mu = u_proj.view(nf, nx, nt, 1).cpu()\n",
    "    var = u_var.view(nf, nx, nt, 1).cpu()\n",
    "    std = torch.sqrt(var)\n",
    "    # sqrt/square round trip: a negative projected variance shows up as NaN in both var and std.\n",
    "    var = torch.square(std)\n",
    "    return mu, var, std\n",
    "\n",
    "\n",
    "def _conservation_error(mu, mass_rhs_func, t, tpred, grid):\n",
    "    \"\"\"Per-example conservation error of `mu`.\n",
    "\n",
    "    |probconserv.get_empirical_mass_rhs(mu) - mass_rhs_func(inputs)| summed over the last axis,\n",
    "    where inputs are built from t[slice(*tpred)] and `grid` for every example.\n",
    "    \"\"\"\n",
    "    import probconserv\n",
    "\n",
    "    t_sliced = t[slice(*tpred)]\n",
    "    ts = repeat(t_sliced, \"nt -> nf nt\", nf=mu.shape[0])\n",
    "    xs = repeat(grid, \"nx -> nf nx\", nf=mu.shape[0])\n",
    "    inputs = meshgrid(ts, xs)\n",
    "    return (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs_func(inputs)).abs().sum(dim=-1)\n",
    "\n",
    "\n",
    "def _split_statistics(model, x_data, y_data, t, tpred, grid, dataset_class,\n",
    "                      apply_projection, device, projection_fn, max_iter):\n",
    "    \"\"\"Predict on one split, optionally project onto the constraint, and return (mu, std, stats).\"\"\"\n",
    "    import utils\n",
    "\n",
    "    x_data = x_data.to(device)\n",
    "    mu, var, std = _predict_with_uncertainty(model, x_data)\n",
    "    mass_rhs_func = dataset_class.get_mass_rhs_func(x=x_data.cpu())\n",
    "\n",
    "    if apply_projection:\n",
    "        mu, var, std = _project_onto_constraint(\n",
    "            mu, var, x_data, model.full_residual, device, projection_fn, max_iter\n",
    "        )\n",
    "\n",
    "    cerr = _conservation_error(mu, mass_rhs_func, t, tpred, grid)\n",
    "\n",
    "    stats = utils.compute_all_metrics_avg((mu, var), y_data, {})\n",
    "    stats[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, var, y_data).item()\n",
    "    stats[\"rmsce_all\"] = utils.compute_rmsce(mu, var, y_data).item()\n",
    "    stats[\"cerr_by_example\"] = cerr.tolist()\n",
    "    stats[\"mcerr\"] = cerr.mean().item()\n",
    "    return mu, std, stats\n",
    "\n",
    "\n",
    "def compute_statistics(\n",
    "    model,\n",
    "    x_data,\n",
    "    y_data,\n",
    "    t,\n",
    "    tpred,\n",
    "    grid,\n",
    "    dataset_class,\n",
    "    apply_probconserv=False,\n",
    "    plot=False,\n",
    "    x_data_test=None,\n",
    "    y_data_test=None,\n",
    "    return_latex=False,\n",
    "    name=\"Model\",\n",
    "    projection_fn=None,\n",
    "    projection_max_iter=30,\n",
    "):\n",
    "    \"\"\"Compute accuracy, calibration and conservation metrics on a main split and an optional test split.\n",
    "\n",
    "    Args:\n",
    "        model: network whose forward returns a mean tensor or a (mean, variance) tuple.\n",
    "        x_data, y_data: inputs / targets of the main split.\n",
    "        t, tpred, grid: time grid, prediction-time slice bounds (used as t[slice(*tpred)]) and space grid.\n",
    "        dataset_class: provides get_mass_rhs_func(x=...) used for the conservation error.\n",
    "        apply_probconserv: despite the name, projects the predicted mean/variance onto the model's\n",
    "            nonlinear constraint (model.full_residual) with `projection_fn`; ProbConserv is not called.\n",
    "        plot: show mean ± 3σ against the truth for the first sample of the main split.\n",
    "        x_data_test, y_data_test: optional second split; test metrics are computed only if both are given.\n",
    "        return_latex: also return a LaTeX table row (None unless test metrics exist).\n",
    "        name: first column of the LaTeX row.\n",
    "        projection_fn: projection routine called as\n",
    "            projection_fn(mu, var, x, residual_fn, max_iter=...) -> (mu, var);\n",
    "            None means project_and_stats_orth (orthogonal projection). Pass project_and_stats\n",
    "            for the oblique projection described at the top of the notebook.\n",
    "        projection_max_iter: max_iter forwarded to projection_fn.\n",
    "\n",
    "    Returns:\n",
    "        (stats, test_stats), or (stats, test_stats, latex_row) when return_latex is True.\n",
    "        test_stats is None when no test split is given.\n",
    "    \"\"\"\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    # Resolve the default lazily so the projection helper is only required when it is used.\n",
    "    if apply_probconserv and projection_fn is None:\n",
    "        projection_fn = project_and_stats_orth\n",
    "\n",
    "    device = next(model.parameters()).device\n",
    "    split_args = dict(\n",
    "        model=model,\n",
    "        t=t,\n",
    "        tpred=tpred,\n",
    "        grid=grid,\n",
    "        dataset_class=dataset_class,\n",
    "        apply_projection=apply_probconserv,\n",
    "        device=device,\n",
    "        projection_fn=projection_fn,\n",
    "        max_iter=projection_max_iter,\n",
    "    )\n",
    "\n",
    "    mu, std, stats = _split_statistics(x_data=x_data, y_data=y_data, **split_args)\n",
    "\n",
    "    # --- Test dataset ---\n",
    "    test_stats = None\n",
    "    if x_data_test is not None and y_data_test is not None:\n",
    "        _, _, test_stats = _split_statistics(x_data=x_data_test, y_data=y_data_test, **split_args)\n",
    "\n",
    "    # --- Optional plot (main split) ---\n",
    "    if plot:\n",
    "        t_idx = 1\n",
    "        param_idx = 0\n",
    "        with torch.no_grad():\n",
    "            plt.ylabel(f\"u(x, t={t[slice(*tpred)][t_idx]:.2f})\")\n",
    "            plt.xlabel(\"x\")\n",
    "            plt.title(f\"Predicted vs True (param = {x_data[param_idx,0,0,0].item():.2f})\")\n",
    "            mu_plot = mu[param_idx, :, t_idx, 0]\n",
    "            std_plot = std[param_idx, :, t_idx, 0]\n",
    "            y_true_plot = y_data[param_idx, :, t_idx, 0]\n",
    "            plt.plot(grid, mu_plot, '--', lw=2, label=\"μ ± 3σ\")\n",
    "            plt.fill_between(grid, mu_plot + 3*std_plot, mu_plot - 3*std_plot, alpha=0.2)\n",
    "            plt.plot(grid, y_true_plot, color=\"green\", label=\"true\")\n",
    "            plt.legend()\n",
    "            plt.show()\n",
    "\n",
    "    # --- Optional LaTeX row: main-split metrics followed by test-split metrics ---\n",
    "    latex_row = None\n",
    "    if return_latex and test_stats is not None:\n",
    "        latex_row = (\n",
    "            f\"{name} & \"\n",
    "            f\"{stats['mse']:.2E} & {stats['nMeRCI_all']:.2E} & {stats['rmsce_all']:.2E} & {stats['mcerr']:.2E} & {stats['crps']:.2E} & \"\n",
    "            f\"{test_stats['mse']:.2E} & {test_stats['nMeRCI_all']:.2E} & {test_stats['rmsce_all']:.2E} & {test_stats['mcerr']:.2E} & {test_stats['crps']:.2E} \\\\\\\\\"\n",
    "        )\n",
    "\n",
    "    return (stats, test_stats, latex_row) if return_latex else (stats, test_stats)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "5e892f52-ecfa-4592-88ef-08d411f65be9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ProbHardE2E metrics on the training split and the OOD test split, plus a LaTeX table row.\n",
    "train_stats, test_stats, latex = compute_statistics(\n",
    "    model,\n",
    "    x_train,\n",
    "    y_train,\n",
    "    t=t,\n",
    "    tpred=tpred,\n",
    "    grid=grid,\n",
    "    dataset_class=dataset_class,\n",
    "    apply_probconserv=True,\n",
    "    plot=False,\n",
    "    x_data_test=x_ood_test,\n",
    "    y_data_test=y_ood_test,\n",
    "    return_latex=True,\n",
    "    name=\"ProbHardE2E\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "6119a5a8-6ac8-4d29-809a-25c3218a9c4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free memory after evaluation: collect unreachable Python objects first so that\n",
    "# torch can then hand the cached, now-unused CUDA blocks back to the driver.\n",
    "import gc\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "1c2eb841-f9bc-4b49-81de-58620aee2a4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "43.509071692824364"
      ]
     },
     "execution_count": 153,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OOD test MSE expressed in units of 1e-6.\n",
    "1e6 * test_stats[\"mse\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "fa62601e-6198-464e-99e7-5b2efa38e8eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14.803041517734528"
      ]
     },
     "execution_count": 154,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OOD test CRPS expressed in units of 1e-4.\n",
    "1e4 * test_stats[\"crps\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "1c54d2f0-cfae-4b5a-ac1f-6d323f0d3d00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ProbHardE2E & 3.98E-05 & 7.83E-01 & 2.02E-01 & 3.59E-02 & 1.55E-03 & 4.35E-05 & 5.07E-01 & 2.05E-01 & 3.34E-02 & 1.48E-03 \\\\\\\\'"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LaTeX row: name, then mse, nMeRCI, rmsce, mcerr, crps for the training split, followed by the same five for the OOD test split.\n",
    "latex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44a1aec-0c82-40ef-b406-13f57d0bb081",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c9f4d06-2687-4058-8c02-941af51595a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean constraint-residual norm of the relu-clamped ProbConserv mean new_mu.\n",
    "# NOTE(review): in the cells shown here, nf and _m are only assigned inside compute_statistics,\n",
    "# so this line depends on leftover kernel state -- confirm where nf, _m, vmap and full_residual come from.\n",
    "# NOTE(review): norm(dim=0) reduces across samples; if the intent is the per-sample residual norm\n",
    "# averaged over samples, dim=-1 looks like the right axis -- TODO confirm the residual's shape.\n",
    "torch.mean(torch.norm(vmap(full_residual)(torch.relu(new_mu.reshape(nf,-1)).to(device), _m), dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad73c5e5-51d1-41a9-868b-aa0a101bd2b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ProbConserv baseline on the training inputs, with precis_g=np.inf (presumably a hard constraint).\n",
    "# NOTE(review): mu/std must be the predictions for these training inputs -- confirm which cell set them.\n",
    "x = train_loader.dataset.tensors[0]\n",
    "y = train_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "\n",
    "new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "    mu=mu[:, :, :, 0],\n",
    "    std=std[:, :, :, 0],\n",
    "    mass_rhs_func=mass_rhs_func,\n",
    "    t=t,\n",
    "    tpred=tpred,\n",
    "    grid_train=grid,\n",
    "    precis_g=np.inf,\n",
    "    second_deriv_alpha=None,\n",
    ")\n",
    "# Restore the 4th axis that was dropped by indexing [:, :, :, 0] above.\n",
    "new_mu, new_std = new_mu[:, :, :, None], new_std[:, :, :, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e1d8ff4-50fc-4cdc-b824-ea2e4d850060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t_idx = len(t[slice(*tpred)])//2\n",
    "t_idx = 1\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):\n",
    "    with torch.no_grad():\n",
    "        plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "        plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k = x_train[parameters_idx,0,0,0], dataset = dataset))\n",
    "        plt.xlabel(\"x\")\n",
    "        plt.plot(grid, new_mu[parameters_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "        plt.fill_between(grid, new_mu[parameters_idx,:,t_idx,0]+3*new_std[parameters_idx,:,t_idx,0], new_mu[parameters_idx,:,t_idx,0]-3*new_std[parameters_idx,:,t_idx,0], alpha=0.2)\n",
    "        plt.plot(grid, y_train[parameters_idx,:,t_idx,0], color = \"green\", label = \"true\")\n",
    "        print(torch.norm(y_train[parameters_idx,:,t_idx,0] - new_mu[parameters_idx,:,t_idx,0]))\n",
    "        plt.legend()\n",
    "        # plt.ylim(-1.0,1.5)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48dcabf-7ee7-4671-9050-8e8e5340a020",
   "metadata": {},
   "outputs": [],
   "source": [
    "e2e_stats_train = utils.compute_all_metrics_avg((mu, torch.square(std)), y_train, {})\n",
    "e2e_stats_train[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, torch.square(std), y_train).item()\n",
    "e2e_stats_train[\"rmsce_all\"] = utils.compute_rmsce(mu, torch.square(std), y_train).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bf0a238-c09a-4d2d-b666-01b025ab0f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "e2e_probconserv_stats_train = utils.compute_all_metrics_avg((new_mu, torch.square(new_std)), y_train, {})\n",
    "e2e_probconserv_stats_train[\"nMeRCI_all\"] = utils.compute_nMeRCI(new_mu, torch.square(new_std), y_train).item()\n",
    "e2e_probconserv_stats_train[\"rmsce_all\"] = utils.compute_rmsce(new_mu, torch.square(new_std), y_train).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d75eddde-47c9-434f-ab24-e999e03cb312",
   "metadata": {},
   "outputs": [],
   "source": [
    "cerr = (probconserv.get_empirical_mass_rhs(mu[:, :,  :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "new_cerr = (probconserv.get_empirical_mass_rhs(new_mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "e2e_stats_train[\"cerr_by_example\"] = cerr.tolist()\n",
    "e2e_stats_train[\"mcerr\"] = cerr.mean().item()\n",
    "e2e_probconserv_stats_train[\"cerr_by_example\"] = new_cerr.tolist()\n",
    "e2e_probconserv_stats_train[\"mcerr\"] = new_cerr.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50627394-b489-4944-a427-98f04dde6b82",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(x_ood_test.to(device))\n",
    "\n",
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "if model.probconserv:\n",
    "    _mu, _var, = out[0].cpu(), out[1].cpu()\n",
    "    _std = torch.sqrt(_var)\n",
    "    mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "    new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "                                                    mu=_mu[:, :, :, 0], \n",
    "                                                    std=_std[:, :, :, 0], \n",
    "                                                    mass_rhs_func=mass_rhs_func, \n",
    "                                                    t=t, \n",
    "                                                    tpred=tpred, \n",
    "                                                    grid_train=grid, \n",
    "                                                    precis_g=np.inf,\n",
    "                                                    second_deriv_alpha=None,\n",
    "                                                    )\n",
    "    out = (new_mu.unsqueeze(-1), torch.square(new_std).unsqueeze(-1))\n",
    "\n",
    "mu, var, = out[0].cpu(), out[1].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3639a00a-4839-4317-98f9-549ba4ab4656",
   "metadata": {},
   "outputs": [],
   "source": [
    "std = torch.sqrt(var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1eddcf4-0349-4c22-819c-d8bf5fd81cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_idx = 1\n",
    "parameter_idx = 0\n",
    "with torch.no_grad():\n",
    "    plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "    plt.title(\"Learning Heat Equation for parameter = {k:.2f}\".format(k = x_ood_test[parameter_idx,0,0,0]))\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.plot(grid, mu[parameter_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "    plt.fill_between(grid, mu[parameter_idx,:,t_idx,0]+3*std[parameter_idx,:,t_idx,0], mu[parameter_idx,:,t_idx,0]-3*std[parameter_idx,:,t_idx,0], alpha=0.2)\n",
    "    plt.plot(grid, y_ood_test[parameter_idx,:,t_idx,:], color = \"green\", label = \"true\")\n",
    "    plt.legend(loc=\"upper right\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee646529-35a6-40d2-898a-288fd1c8240b",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = ood_test_loader.dataset.tensors[0]\n",
    "y = ood_test_loader.dataset.tensors[1]\n",
    "mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)\n",
    "new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "    mu=mu[:, :, :, 0], \n",
    "    std=std[:, :, :, 0], \n",
    "    mass_rhs_func=mass_rhs_func, \n",
    "    t=t, \n",
    "    tpred=tpred, \n",
    "    grid_train=grid, \n",
    "    precis_g=np.inf,\n",
    "    second_deriv_alpha=None,\n",
    ")\n",
    "new_mu = new_mu[:, :, :, None]\n",
    "new_std = new_std[:, :, :, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd8c116a-f7e0-4af0-a6d3-b506dee7945b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t_idx = len(t[slice(*tpred)])//2\n",
    "t_idx = 1\n",
    "\n",
    "for parameters_idx in range(0, 1, 5):\n",
    "    with torch.no_grad():\n",
    "        plt.ylabel(\"u(x,t={t:.2f})\".format(t=t[slice(*tpred)][t_idx]))\n",
    "        plt.title(\"Learning {dataset} for parameter = {k:.2f}\".format(k = x_ood_test[parameters_idx,0,0,0], dataset = dataset))\n",
    "        plt.xlabel(\"x\")\n",
    "        plt.plot(grid, new_mu[parameters_idx,:,t_idx,0], '--', lw=2, label = \"predicted $\\mu$ and $\\pm 3\\sigma$ (varFNO)\")\n",
    "        plt.fill_between(grid, new_mu[parameters_idx,:,t_idx,0]+3*new_std[parameters_idx,:,t_idx,0], new_mu[parameters_idx,:,t_idx,0]-3*new_std[parameters_idx,:,t_idx,0], alpha=0.2)\n",
    "        plt.plot(grid, y_ood_test[parameters_idx,:,t_idx,0], color = \"green\", label = \"true\")\n",
    "        print(torch.norm(y_ood_test[parameters_idx,:,t_idx,0] - new_mu[parameters_idx,:,t_idx,0]))\n",
    "        plt.legend()\n",
    "        # plt.ylim(-1.0,1.5)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "463e36f8-0ccd-4319-a639-b3bccdc998a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "e2e_stats_test = utils.compute_all_metrics_avg((mu, torch.square(std)), y_ood_test, {})\n",
    "e2e_stats_test[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, torch.square(std), y_ood_test).item()\n",
    "e2e_stats_test[\"rmsce_all\"] = utils.compute_rmsce(mu, torch.square(std), y_ood_test).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "052adf56-3cac-4fd5-a5d8-83ca0e91a969",
   "metadata": {},
   "outputs": [],
   "source": [
    "e2e_probconserv_stats_test = utils.compute_all_metrics_avg((new_mu, torch.square(new_std)), y_ood_test, {})\n",
    "e2e_probconserv_stats_test[\"nMeRCI_all\"] = utils.compute_nMeRCI(new_mu, torch.square(new_std), y_ood_test).item()\n",
    "e2e_probconserv_stats_test[\"rmsce_all\"] = utils.compute_rmsce(new_mu, torch.square(new_std), y_ood_test).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef17475e-fdca-446c-9b90-1f90beec49d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "cerr = (probconserv.get_empirical_mass_rhs(mu[:, :,  :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "new_cerr = (probconserv.get_empirical_mass_rhs(new_mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "\n",
    "e2e_stats_test[\"cerr_by_example\"] = cerr.tolist()\n",
    "e2e_stats_test[\"mcerr\"] = cerr.mean().item()\n",
    "e2e_probconserv_stats_test[\"cerr_by_example\"] = new_cerr.tolist()\n",
    "e2e_probconserv_stats_test[\"mcerr\"] = new_cerr.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f87a91-2bad-44c5-9633-1f6a03f77d68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from decimal import Decimal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82377954-2227-47c4-bbc3-7c1c95682b6c",
   "metadata": {},
   "outputs": [],
   "source": [
    " f\"{ucons_stats_train['mcerr']:.2}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd09fb69-7adb-4f8f-811a-45a934dc46d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "e2e_stats_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffa40a8c-b09f-4f57-bbb9-67b241d88629",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dump_to_latex(ucons_stats_train, ucons_stats_test,  probconserv_stats_train, probconserv_stats_test, e2e_stats_train, e2e_stats_test, e2e_probconserv_stats_train, e2e_probconserv_stats_test):\n",
    "    table_str = f\"\"\"\n",
    "    Unconstrained (VarianceNO) & {ucons_stats_train['mse']:.2E} & {ucons_stats_train['nMeRCI_all']:.2E} & {ucons_stats_train['rmsce_all']:.2E} & {ucons_stats_train['mcerr']:.2E} & {ucons_stats_train['crps']:.2E} & {ucons_stats_test['mse']:.2E} & {ucons_stats_test['nMeRCI_all']:.2E} & {ucons_stats_test['rmsce_all']:.2E} & {ucons_stats_test['mcerr']:.2E} & {ucons_stats_test['crps']:.2E} \\\\\\\\\n",
    "    \\\\texttt{{ProbConserv}} & {probconserv_stats_train['mse']:.2E} & {probconserv_stats_train['nMeRCI_all']:.2E} & {probconserv_stats_train['rmsce_all']:.2E} & {probconserv_stats_train['mcerr']:.2E} & {probconserv_stats_train['crps']:.2E} & {probconserv_stats_test['mse']:.2E} & {probconserv_stats_test['nMeRCI_all']:.2E} & {probconserv_stats_test['rmsce_all']:.2E} & {probconserv_stats_test['mcerr']:.2E} & {probconserv_stats_test['crps']:.2E} \\\\\\\\\n",
    "    \\\\ourmethod{{}} & {e2e_stats_train['mse']:.2E} & {e2e_stats_train['nMeRCI_all']:.2E} & {e2e_stats_train['rmsce_all']:.2E} & {e2e_stats_train['mcerr']:.2E} & {e2e_stats_train['crps']:.2E} & {e2e_stats_test['mse']:.2E} & {e2e_stats_test['nMeRCI_all']:.2E} & {e2e_stats_test['rmsce_all']:.2E} & {e2e_stats_test['mcerr']:.2E} & {e2e_stats_test['crps']:.2E} \\\\\\\\\n",
    "    \\\\ourmethod{{}} + \\\\texttt{{ProbConserv}} & {e2e_probconserv_stats_train['mse']:.2E} & {e2e_probconserv_stats_train['nMeRCI_all']:.2E} & {e2e_probconserv_stats_train['rmsce_all']:.2E} & {e2e_probconserv_stats_train['mcerr']:.2E} & {e2e_probconserv_stats_train['crps']:.2E} & {e2e_probconserv_stats_test['mse']:.2E} & {e2e_probconserv_stats_test['nMeRCI_all']:.2E} & {e2e_probconserv_stats_test['rmsce_all']:.2E} & {e2e_probconserv_stats_test['mcerr']:.2E} & {e2e_probconserv_stats_test['crps']:.2E} \\\\\\\\\n",
    "    \"\"\"\n",
    "    return table_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c95a912a-bd68-459a-abf7-1c5f03b5661a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dump_to_latex(ucons_stats_train, ucons_stats_test,  probconserv_stats_train, probconserv_stats_test, e2e_stats_train, e2e_stats_test, e2e_probconserv_stats_train, e2e_probconserv_stats_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "654cef07-235c-492e-9960-9b3d236949dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_statistics(\n",
    "    model, \n",
    "    x_data, \n",
    "    y_data, \n",
    "    t, \n",
    "    tpred, \n",
    "    grid, \n",
    "    dataset_class, \n",
    "    apply_probconserv=False, \n",
    "    plot=False,\n",
    "    x_data_test=None, \n",
    "    y_data_test=None,\n",
    "    return_latex=False,\n",
    "    name=\"Model\"\n",
    "):\n",
    "    import torch\n",
    "    import utils\n",
    "    import probconserv\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    device = next(model.parameters()).device\n",
    "    x_data = x_data.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out = model(x_data)\n",
    "\n",
    "    if isinstance(out, tuple):\n",
    "        mu, var = out[0].cpu(), out[1].cpu()\n",
    "        std = torch.sqrt(var)\n",
    "    else:\n",
    "        mu = out.cpu()\n",
    "        std = torch.zeros_like(mu)\n",
    "        var = torch.square(std)\n",
    "\n",
    "    x_cpu = x_data.cpu()\n",
    "    mass_rhs_func = dataset_class.get_mass_rhs_func(x=x_cpu)\n",
    "\n",
    "    if apply_probconserv:\n",
    "        new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(\n",
    "            mu=mu[:, :, :, 0],\n",
    "            std=std[:, :, :, 0],\n",
    "            mass_rhs_func=mass_rhs_func,\n",
    "            t=t,\n",
    "            tpred=tpred,\n",
    "            grid_train=grid,\n",
    "            precis_g=float('inf'),\n",
    "            second_deriv_alpha=None,\n",
    "        )\n",
    "        mu = new_mu.unsqueeze(-1)\n",
    "        std = new_std.unsqueeze(-1)\n",
    "        var = torch.square(std)\n",
    "        cerr = (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)\n",
    "    else:\n",
    "        cerr = (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs_func(rearrange(x_cpu, \"nf nx nt 1-> nf nt nx 1\"))).abs().sum(dim=-1)\n",
    "\n",
    "    stats = utils.compute_all_metrics_avg((mu, var), y_data, {})\n",
    "    stats[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu, var, y_data).item()\n",
    "    stats[\"rmsce_all\"] = utils.compute_rmsce(mu, var, y_data).item()\n",
    "    stats[\"cerr_by_example\"] = cerr.tolist()\n",
    "    stats[\"mcerr\"] = cerr.mean().item()\n",
    "\n",
    "    # --- Test dataset ---\n",
    "    test_stats = None\n",
    "    if x_data_test is not None and y_data_test is not None:\n",
    "        x_data_test = x_data_test.to(device)\n",
    "        with torch.no_grad():\n",
    "            test_out = model(x_data_test)\n",
    "\n",
    "        if isinstance(test_out, tuple):\n",
    "            mu_test, var_test = test_out[0].cpu(), test_out[1].cpu()\n",
    "            std_test = torch.sqrt(var_test)\n",
    "        else:\n",
    "            mu_test = test_out.cpu()\n",
    "            std_test = torch.zeros_like(mu_test)\n",
    "            var_test = torch.square(std_test)\n",
    "\n",
    "        x_test_cpu = x_data_test.cpu()\n",
    "        test_mass_rhs_func = dataset_class.get_mass_rhs_func(x=x_test_cpu)\n",
    "\n",
    "        if apply_probconserv:\n",
    "            new_mu_test, new_std_test, _, test_mass_rhs = probconserv.apply_constraint(\n",
    "                mu=mu_test[:, :, :, 0],\n",
    "                std=std_test[:, :, :, 0],\n",
    "                mass_rhs_func=test_mass_rhs_func,\n",
    "                t=t,\n",
    "                tpred=tpred,\n",
    "                grid_train=grid,\n",
    "                precis_g=float('inf'),\n",
    "                second_deriv_alpha=None,\n",
    "            )\n",
    "            mu_test = new_mu_test.unsqueeze(-1)\n",
    "            std_test = new_std_test.unsqueeze(-1)\n",
    "            var_test = torch.square(std_test)\n",
    "            cerr_test = (probconserv.get_empirical_mass_rhs(mu_test[:, :, :, 0]) - test_mass_rhs).abs().sum(dim=-1)\n",
    "        else:\n",
    "            t_sliced = t[slice(*tpred)]\n",
    "            ts = repeat(t_sliced, \"nt -> nf nt\", nf=mu.shape[0])\n",
    "            xs = repeat(grid_train, \"nx -> nf nx\", nf=mu.shape[0])\n",
    "            inputs = meshgrid(ts, xs)\n",
    "            inputs = inputs.to(mu.device)\n",
    "            cerr_test = (probconserv.get_empirical_mass_rhs(mu_test[:, :, :, 0]) - test_mass_rhs_func(rearrange(x_test_cpu, \"nf nx nt 1-> nf nt nx 1\"))).abs().sum(dim=-1)\n",
    "\n",
    "        test_stats = utils.compute_all_metrics_avg((mu_test, var_test), y_data_test, {})\n",
    "        test_stats[\"nMeRCI_all\"] = utils.compute_nMeRCI(mu_test, var_test, y_data_test).item()\n",
    "        test_stats[\"rmsce_all\"] = utils.compute_rmsce(mu_test, var_test, y_data_test).item()\n",
    "        test_stats[\"cerr_by_example\"] = cerr_test.tolist()\n",
    "        test_stats[\"mcerr\"] = cerr_test.mean().item()\n",
    "\n",
    "    # --- Optional plot ---\n",
    "    if plot:\n",
    "        t_idx = 1\n",
    "        param_idx = 0\n",
    "        with torch.no_grad():\n",
    "            plt.ylabel(f\"u(x, t={t[slice(*tpred)][t_idx]:.2f})\")\n",
    "            plt.xlabel(\"x\")\n",
    "            plt.title(f\"Predicted vs True (param = {x_data[param_idx,0,0,0].item():.2f})\")\n",
    "            mu_plot = mu[param_idx, :, t_idx, 0]\n",
    "            std_plot = std[param_idx, :, t_idx, 0]\n",
    "            y_true_plot = y_data[param_idx, :, t_idx, 0]\n",
    "            plt.plot(grid, mu_plot, '--', lw=2, label=\"μ ± 3σ\")\n",
    "            plt.fill_between(grid, mu_plot + 3*std_plot, mu_plot - 3*std_plot, alpha=0.2)\n",
    "            plt.plot(grid, y_true_plot, color=\"green\", label=\"true\")\n",
    "            plt.legend()\n",
    "            plt.show()\n",
    "\n",
    "    # --- Optional LaTeX row ---\n",
    "    latex_row = None\n",
    "    if return_latex and test_stats:\n",
    "        latex_row = (\n",
    "            f\"{name} & \"\n",
    "            f\"{stats['mse']:.2E} & {stats['nMeRCI_all']:.2E} & {stats['rmsce_all']:.2E} & {stats['mcerr']:.2E} & {stats['crps']:.2E} & \"\n",
    "            f\"{test_stats['mse']:.2E} & {test_stats['nMeRCI_all']:.2E} & {test_stats['rmsce_all']:.2E} & {test_stats['mcerr']:.2E} & {test_stats['crps']:.2E} \\\\\\\\\"\n",
    "        )\n",
    "\n",
    "    return (stats, test_stats, latex_row) if return_latex else (stats, test_stats)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47db742d-7441-4c5f-a60c-a7c6bfd76632",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_stats, test_stats, latex = compute_statistics(\n",
    "    model,\n",
    "    x_train, y_train,\n",
    "    x_data_test=x_ood_test, \n",
    "    y_data_test=y_ood_test,\n",
    "    t=t, tpred=tpred, grid=grid,\n",
    "    dataset_class=dataset_class,\n",
    "    apply_probconserv=True,\n",
    "    plot=False,\n",
    "    return_latex=True,\n",
    "    name=\"ProbConserv\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a1914c-1c99-4c23-9fe7-fafe7a2a5d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "latex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af0c8f58-4cae-495a-8cfb-ee81163a961c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "optprobconserv2",
   "language": "python",
   "name": "optprobconserv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
