{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Import Necessary Packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "GUrtasdlwqId"
      },
      "outputs": [],
      "source": [
        "# Standard Libraries\n",
        "import os\n",
        "import time, copy\n",
        "import math\n",
        "import random\n",
        "from datetime import datetime\n",
        "from collections import deque\n",
        "from itertools import chain\n",
        "from typing import Dict, Optional, Tuple\n",
        "\n",
        "# Scientific and Data Libraries\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from scipy import optimize\n",
        "from scipy.stats import norm\n",
        "\n",
        "# PyTorch\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.nn.parameter import Parameter\n",
        "from torch.nn.modules.utils import _pair\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "import torch.nn.init as init\n",
        "from math import log\n",
        "\n",
        "\n",
        "# Environments\n",
        "import gym\n",
        "import gymnasium as gym\n",
        "from gymnasium import core, spaces\n",
        "from gymnasium.spaces import Box, Dict\n",
        "from gymnasium.wrappers import RescaleAction\n",
        "import dm_env\n",
        "from dm_control import suite\n",
        "\n",
        "# Argument Parsing\n",
        "import argparse\n",
        "\n",
        "\n",
        "# Argument parser setup\n",
        "parser = argparse.ArgumentParser()\n",
        "parser.add_argument(\"-f\", required=False)  # Allows running in Jupyter notebooks\n",
        "args, unknown = parser.parse_known_args()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# BNN Architucture\n",
        "\n",
        "This file consider all the funcitons for creating a Bayesian NN applied for model training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "class VariationalBayesianLinear(nn.Module):\n",
        "    __constants__ = [\"in_features\", \"out_features\"]\n",
        "    in_features: int\n",
        "    out_features: int\n",
        "    weight: torch.Tensor\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        in_features: int,\n",
        "        out_features: int,\n",
        "        bias: bool = True,\n",
        "        prior_log_sig2=0,\n",
        "        log_sig2_init=-4.6,\n",
        "    ) -> None:\n",
        "        super().__init__()\n",
        "        self.in_features = in_features\n",
        "        self.out_features = out_features\n",
        "        self.has_bias = bias\n",
        "        self.weight_mu = nn.Parameter(torch.empty((out_features, in_features)))\n",
        "        self.weight_log_sig2 = nn.Parameter(torch.empty((out_features, in_features)))\n",
        "        self.weight_mu_prior = nn.Parameter(\n",
        "            torch.zeros((out_features, in_features)), requires_grad=False\n",
        "        )\n",
        "        self.weight_log_sig2_prior = nn.Parameter(\n",
        "            prior_log_sig2 * torch.ones((out_features, in_features)),\n",
        "            requires_grad=False,\n",
        "        )\n",
        "        if self.has_bias:\n",
        "            self.bias_mu = nn.Parameter(torch.Tensor(out_features))\n",
        "        else:\n",
        "            self.register_parameter(\"bias_mu\", None)\n",
        "        self.reset_parameters(log_sig2_init)\n",
        "\n",
        "    def reset_parameters(self, log_sig2_init) -> None:\n",
        "        init.kaiming_uniform_(self.weight_mu, a=math.sqrt(self.weight_mu.shape[1]))\n",
        "        init.constant_(self.weight_log_sig2, log_sig2_init)\n",
        "        if self.has_bias:\n",
        "            init.zeros_(self.bias_mu)\n",
        "\n",
        "    def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
        "        output_mu = F.linear(input, self.weight_mu, self.bias_mu)\n",
        "        output_sig2 = F.linear(\n",
        "            input.pow(2), self.weight_log_sig2.exp(), bias=None  # .clamp(-10, 10)\n",
        "        )\n",
        "        return output_mu + output_sig2.sqrt() * torch.randn_like(output_sig2)\n",
        "\n",
        "    def get_mean_var(self, input: torch.Tensor) -> torch.Tensor:\n",
        "        mu = F.linear(input, self.weight_mu, self.bias_mu)\n",
        "        sig2 = F.linear(input**2, self.weight_log_sig2.exp(), bias=None)\n",
        "        return mu, sig2\n",
        "\n",
        "    def extra_repr(self) -> str:\n",
        "        return \"in_features={}, out_features={}, bias={}\".format(\n",
        "            self.in_features, self.out_features, self.has_bias\n",
        "        )\n",
        "\n",
        "    def update_prior(self, newprior):\n",
        "        self.weight_mu_prior.data = newprior.weight_mu.data.clone()\n",
        "        self.weight_mu_prior.data.requires_grad = False\n",
        "        self.weight_log_sig2_prior.data = newprior.weight_log_sig2.data.clone()\n",
        "        self.weight_log_sig2_prior.data.requires_grad = False\n",
        "\n",
        "    def kl_loss(self):\n",
        "        kl_weight = 0.5 * (\n",
        "            self.weight_log_sig2_prior\n",
        "            - self.weight_log_sig2\n",
        "            + (\n",
        "                self.weight_log_sig2.exp()\n",
        "                + (self.weight_mu_prior - self.weight_mu) ** 2\n",
        "            )\n",
        "            / self.weight_log_sig2_prior.exp()\n",
        "            - 1.0\n",
        "        )\n",
        "        kl = kl_weight.sum()\n",
        "        n = len(self.weight_mu.view(-1))\n",
        "        return kl, n\n",
        "    \n",
        "\n",
        "    def kl_loss_informative(self, prior_layer):\n",
        "        \"\"\"KL loss comparing the posterior to an informative prior\"\"\"\n",
        "        # Get prior parameters from the corresponding layer in priornet\n",
        "        weight_mu_prior = prior_layer.weight_mu\n",
        "        weight_log_sig2_prior = prior_layer.weight_log_sig2  # log(σ²) prior\n",
        "\n",
        "\n",
        "        # Compute KL divergence\n",
        "        kl_weight = 0.5 * (\n",
        "            weight_log_sig2_prior\n",
        "            - self.weight_log_sig2\n",
        "            + (\n",
        "                self.weight_log_sig2.exp()\n",
        "                + (weight_mu_prior - self.weight_mu) ** 2\n",
        "            )\n",
        "            / weight_log_sig2_prior.exp()\n",
        "            - 1.0\n",
        "        )\n",
        "        \n",
        "        kl = kl_weight.sum()\n",
        "        n = len(self.weight_mu.view(-1))\n",
        "        \n",
        "        return kl, n\n",
        "    \n",
        "    \n",
        "def calculate_kl_terms_informative(model: nn.Module, priornet: nn.Module):\n",
        "    \"\"\"Function to calculate KL loss of Bayesian neural network with an informative prior\"\"\"\n",
        "    kl, n = 0, 0\n",
        "    for m, p in zip(model.modules(), priornet.modules()):  # Match layers\n",
        "        if m.__class__.__name__.startswith((\"Variational\")):\n",
        "            kl_, n_ = m.kl_loss_informative(p)  # Pass prior layer\n",
        "            kl += kl_\n",
        "            n += n_\n",
        "    return kl, n\n",
        "\n",
        "\n",
        "class ValueNetVBfull(nn.Module):\n",
        "    def __init__(self, dim_obs, local_reparam=True, n_hidden=256):\n",
        "        super().__init__()\n",
        "        self.arch = nn.Sequential(\n",
        "                VariationalBayesianLinear(dim_obs[0], n_hidden),\n",
        "                nn.ReLU(inplace=True),\n",
        "                ###\n",
        "                VariationalBayesianLinear(n_hidden, n_hidden),\n",
        "                nn.ReLU(inplace=True),\n",
        "                ###\n",
        "        )\n",
        "        self.head = VariationalBayesianLinear(n_hidden, 1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        h = self.arch(x)\n",
        "        return self.head(h), self.head.get_mean_var(h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Recursive bound calculation\n",
        "This file contains all funcitons necessary for bound calculation for all non-recursive and recursive baselines."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "def KL(Q, P):\n",
        "    \"\"\"\n",
        "    Compute Kullback-Leibler (KL) divergence between distributions Q and P.\n",
        "    \"\"\"\n",
        "    return sum([q * math.log(q / p) if q > 0.0 else 0.0 for q, p in zip(Q, P)])\n",
        "\n",
        "\n",
        "def KL_binomial(q, p):\n",
        "    \"\"\"\n",
        "    Compute the KL-divergence between two Bernoulli distributions of probability\n",
        "    of success q and p. That is, Q=(q,1-q), P=(p,1-p).\n",
        "    \"\"\"\n",
        "    return KL([q, 1.0 - q], [p, 1.0 - p])\n",
        "\n",
        "\n",
        "def solve_kl_sup(q, right_hand_side):\n",
        "    \"\"\"\n",
        "    find x such that:\n",
        "        kl( q || x ) = right_hand_side\n",
        "        x > q\n",
        "    \"\"\"\n",
        "\n",
        "    f = lambda x: KL_binomial(q, x) - right_hand_side\n",
        "    if f(1.0 - 1e-9) <= 0.0:\n",
        "        return 1.0 - 1e-9\n",
        "    else:\n",
        "        return optimize.brentq(f, q, 1.0 - 1e-11)\n",
        "\n",
        "\n",
        "\n",
        "def solve_kl_inf(q, right_hand_side):\n",
        "    \"\"\"\n",
        "    find x such that:\n",
        "        kl( q || x ) = right_hand_side\n",
        "        x < q\n",
        "    \"\"\"\n",
        " \n",
        "\n",
        "    f = lambda x: KL_binomial(q, x) - right_hand_side\n",
        "    if f(1e-9) <= 0.0:\n",
        "        return 1e-9\n",
        "    else:\n",
        "        return optimize.brentq(f, 1e-11, q)\n",
        "\n",
        "\n",
        "\n",
        "def get_loss(input, target, limit):\n",
        "    loss = (input - target) ** 2\n",
        "    loss = torch.clamp(loss, max=limit)\n",
        "    # emploss = loss / loss.max().item()\n",
        "    return loss\n",
        "\n",
        "\n",
        "def get_kl(mu1, var1, mu2, var2):\n",
        "    return var2.log() - var1.log() + (var1 + (mu1 - mu2) ** 2) / (2 * var2) - 0.5\n",
        "\n",
        "def kl_div_gaussian(eval_loader):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    kl_values = []\n",
        "    for i, batch in enumerate(eval_loader):\n",
        "        kl_ = batch[0][5]\n",
        "        kl_values.append(kl_.item())\n",
        "    print(\"KL Divergence values per step:\", kl_values)\n",
        "    return kl_values\n",
        "\n",
        "\n",
        "\n",
        "################# compute pac-bayes loss terms seprately #################\n",
        "def mcsampling_excess(input_p, target_p, input, target, limit, gamma_t=0.5):\n",
        "    \"\"\"Compute the mean of excess loss delta_j^{\\hat}(h_2, h_1, X, Y), where\"\"\"\n",
        "    loss_prior = 0\n",
        "    loss_posterior = 0\n",
        "    loss_prior = get_loss(input_p, target_p, limit)\n",
        "    loss_posterior = get_loss(input, target, limit)\n",
        "    delta = get_excess(loss_posterior, loss_prior, gamma_t)\n",
        "    n_ = abs(loss_prior.shape[0] - loss_posterior.shape[0])\n",
        "    posteriorloss = loss_posterior.mean().item()\n",
        "    priorloss = loss_prior.mean().item()\n",
        "    \n",
        "\n",
        "    ex_emp = loss_posterior.mean() - gamma_t * (loss_prior[n_:]).mean()\n",
        "    if math.isclose(ex_emp.item(), delta.mean().item(), rel_tol=1e-3):\n",
        "        print(\"loss_excess avg is equal to (avg posteriorloss - gamma * avg priorloss)\")\n",
        "    else:\n",
        "        print( f\"ecxess_of_avg: {ex_emp.item()},excess of all: {delta.mean().item()}\")\n",
        "\n",
        "    return delta.cpu().numpy(), posteriorloss, priorloss\n",
        "\n",
        "\n",
        "def get_excess(loss_posterior, loss_prior, gamma_t):\n",
        "    \"\"\"Compute excess loss delta_j^{\\hat}(h_2, h_1, x, y) for a data and each j.\"\"\"\n",
        "    # added to handel cases where the episode length is not 1000\n",
        "    if loss_prior.shape != loss_posterior.shape:\n",
        "        n = abs(loss_prior.shape[0] - loss_posterior.shape[0])\n",
        "        if loss_prior.shape[0] > loss_posterior.shape[0]:\n",
        "            loss_prior = loss_prior[n:]\n",
        "        else:\n",
        "            raise ValueError(\"loss_posterior is greater than loss_prior\")\n",
        "        \n",
        "    delta = loss_posterior - gamma_t * loss_prior\n",
        "    return delta \n",
        "\n",
        "\n",
        "def compute_B_1(emp_loss, kl, T, n_bound, B, delta_test=0.01, delta=0.025, gamma_t = 0.5):\n",
        "    \"\"\"Compute B_1.\"\"\"\n",
        "    mu = 0 #B / 2\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    inv_1_sup = solve_kl_sup((torch.maximum(torch.tensor(0.0), emp_loss - mu)).mean().item(), np.log(T / delta_test) / n_bound)\n",
        "    kl_sup = solve_kl_sup(\n",
        "        inv_1_sup / (B - mu),  # not 1 it should be B\n",
        "        (kl + np.log((4 * T * np.sqrt(n_bound)) / delta)) / (n_bound),\n",
        "    )\n",
        "    inv_1_inf = solve_kl_inf((torch.maximum(torch.tensor(0.0), mu - emp_loss)).mean().item(), np.log(T / delta_test) / n_bound)\n",
        "    kl_inf = solve_kl_inf(\n",
        "        inv_1_inf / (mu + gamma_t*B ),  # change to mu + gamma_t\n",
        "        (kl + np.log((4 * T * np.sqrt(n_bound)) / delta)) / (n_bound),\n",
        "    )\n",
        "    B_1 = (B - mu) * kl_sup - (mu + gamma_t*B) * kl_inf \n",
        "    return B_1, inv_1_sup, B  * kl_sup \n",
        "    \n",
        "\n",
        "\n",
        "def compute_E_t(loss_excess, kl, T, gamma_t, n_bound, B, delta_test=0.01, delta=0.025):\n",
        "    \"\"\"Compute E_t.\"\"\"\n",
        "    mu = 0 #B / 2\n",
        "    val1 = (torch.maximum(torch.tensor(0.0), loss_excess - mu)).mean().item()\n",
        "    val2 = (torch.maximum(torch.tensor(0.0), mu - loss_excess)).mean().item()\n",
        "    if math.isclose(val1 - val2, loss_excess.mean().item(), rel_tol=1e-3):\n",
        "        print(\"loss_excess avg is equal to (exces sup - exces inf) up to 3 decimal places\")\n",
        "    else:\n",
        "        print( ((torch.maximum(torch.tensor(0.0), loss_excess - mu)).mean() - (torch.maximum(torch.tensor(0.0), mu - loss_excess)).mean()).item(),loss_excess.mean().item())\n",
        "        raise ValueError(\"loss_excess avg is not equal to (exces sup - exces inf)\")\n",
        "    \n",
        "    inv_1 = solve_kl_sup(\n",
        "        (torch.maximum(torch.tensor(0.0), loss_excess - mu)).mean().item(),\n",
        "        np.log(2 * T / delta_test) / n_bound,\n",
        "    )\n",
        "    inv_2 = solve_kl_sup(\n",
        "        inv_1 / (B - mu),  # previously considered as 1 should change to B\n",
        "        (kl + np.log(( T * 4 * np.sqrt(n_bound)) / delta)) / (n_bound),\n",
        "    )\n",
        "    inv_1_inf = solve_kl_inf(\n",
        "        (torch.maximum(torch.tensor(0.0), mu - loss_excess)).mean().item(),\n",
        "        np.log(2 * T / delta_test) / n_bound,\n",
        "    )\n",
        "    inv_2_inf = solve_kl_inf(\n",
        "        inv_1_inf / (mu + gamma_t*B),\n",
        "        (kl + np.log(( T * 4 * np.sqrt(n_bound)) / delta)) / (n_bound),\n",
        "    )\n",
        "    E_t = mu + ((B - mu) * inv_2) - ((mu + gamma_t*B) * inv_2_inf)  # 1 changed to B\n",
        "    kl_over_n = (kl + np.log(( T * 4 * np.sqrt(n_bound)) / delta)) / (n_bound)\n",
        "    return E_t, inv_1/(B - mu), inv_2, inv_1_inf/(mu + gamma_t*B), inv_2_inf, kl_over_n\n",
        "\n",
        "\n",
        "def compute_B_t(B_1, E_ts, gamma_t):\n",
        "    \"\"\"Compute risk of T-step posteriors using the recursive formula:\n",
        "    B_t = E_t + gam * B_{t-1}\n",
        "    \"\"\"\n",
        "    B_ts = [B_1]\n",
        "    for i in range(len(E_ts)):\n",
        "        B_t = B_ts[i] * gamma_t + E_ts[i]\n",
        "        B_ts.append(B_t)\n",
        "    return B_ts\n",
        "\n",
        "########################### compute recursive pac-bayes bound ##################################\n",
        "def compute_risk_rpb(eval_loaders, gamma_t=0.5, delta_test=0.01, delta=0.025):\n",
        "    import torch\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    T = len(eval_loaders)\n",
        "\n",
        "    # Precompute KL divergences for all steps\n",
        "    kl_ts = kl_div_gaussian(eval_loaders)\n",
        "\n",
        "    # Initialize variables\n",
        "    loss_ts = []\n",
        "    E_ts = []\n",
        "    E=[]\n",
        "    B_i=[]\n",
        "    max_B = []\n",
        "    bound_size = []\n",
        "    posterioremploss = []\n",
        "       \n",
        "    # Iterate through each step\n",
        "    for t in range(1, T + 1):\n",
        "        n_bound = len(eval_loaders[t - 1])\n",
        "        print(f\"Current step: {t}\", f\"len of bound: {n_bound}\")\n",
        "        # KL divergence for the current step\n",
        "        kl = kl_ts[t - 1]\n",
        "        upper_limit = eval_loaders[t - 1][0][4]  # B\n",
        "\n",
        "        if t == 1:\n",
        "            risk = torch.cat([element[7] for element in eval_loaders[t - 1]]) # emploss\n",
        "            loss = risk\n",
        "\n",
        "            B_1,_,_ = compute_B_1(loss, kl, T, n_bound, upper_limit,  delta_test, delta, gamma_t)\n",
        "            B_i.append(B_1)\n",
        "            max_B.append(upper_limit)\n",
        "            E.append((0, 0, 0, 0, 0, 0))\n",
        "            loss_ts.append(0)\n",
        "            posterioremploss.append(loss.mean().item())\n",
        "        else:\n",
        "            # Compute (E_t)_{t >= 1}\n",
        "            target_prior = torch.cat([element[1] for element in eval_loaders[t - 2]])\n",
        "            predict_prior = torch.cat([element[2] for element in eval_loaders[t - 2]])\n",
        "            target = torch.cat([element[1] for element in eval_loaders[t - 1]])\n",
        "            predict = torch.cat([element[2] for element in eval_loaders[t - 1]])\n",
        "           \n",
        "            upper_limit = ([element[4] for element in eval_loaders[t - 1]])[0]\n",
        "            max_B_new = upper_limit\n",
        "\n",
        "            emp_losses = mcsampling_excess(\n",
        "                predict_prior,\n",
        "                target_prior,\n",
        "                predict,\n",
        "                target,\n",
        "                upper_limit,\n",
        "                gamma_t=gamma_t,\n",
        "            )\n",
        "            loss_excess = emp_losses[0]\n",
        "            loss_ts.append(emp_losses[0].mean().item())\n",
        "            posterioremploss.append(emp_losses[1])\n",
        "\n",
        "            E_t_, inv_b , Epsilon_plas, invinf_gammab , Epsilon_neg, kl_n = compute_E_t(\n",
        "                torch.tensor(loss_excess), kl, T, gamma_t, n_bound, max_B_new, delta_test, delta\n",
        "            )\n",
        "            if t == T:\n",
        "                B_T_inv = compute_risk_last_invkl( eval_loaders[t - 1], kl, max_B_new, delta_test, delta)\n",
        "            E_ts.append(E_t_)\n",
        "            E.append((E_t_, inv_b, Epsilon_plas, invinf_gammab, Epsilon_neg, kl_n))\n",
        "            B_i.append(B_i[-1] * gamma_t + E_t_)\n",
        "            max_B.append(max_B_new)\n",
        "        bound_size.append(n_bound)\n",
        "\n",
        "    # Compute B_t recursively using B_1, (E_t)_t, and gamma_t\n",
        "    B_ts = compute_B_t(B_1, E_ts, gamma_t)\n",
        "    print(f\"B_ts_list_all steps: {B_ts}, Recursive pac-bayes bound: {B_ts[-1]}\")\n",
        "    print(f\"excess loss: {loss_ts}\")\n",
        "    B_T_MC = 0\n",
        "\n",
        "    return loss_ts, kl_ts, E, B_ts, B_T_MC, max_B , B_T_inv, bound_size, posterioremploss\n",
        "\n",
        "\n",
        "##################### informative inv-kl bound ###########################################\n",
        "def compute_inf_risk_invkl(eval_loader, delta_test=0.01, delta=0.025):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    n_bound = len(eval_loader)\n",
        "    target = torch.cat([element[1] for element in eval_loader])\n",
        "    predict = torch.cat([element[2] for element in eval_loader])\n",
        "    emploss = torch.cat([element[6] for element in eval_loader]) # emploss\n",
        "    B, kl = eval_loader[0][3], eval_loader[0][4].item() # B, kl\n",
        "\n",
        "    N = n_bound\n",
        "    last_kl_inv = {\n",
        "    \"emploss\": 0,\n",
        "    \"rightside\": 0,\n",
        "    \"fullloss\": 0,\n",
        "    \"leftside\": 0\n",
        "    }\n",
        "    risk = 0\n",
        "\n",
        "    inv_1 = solve_kl_sup(emploss.mean().item(), np.log(1 / delta_test) / n_bound)\n",
        "    risk = solve_kl_sup(\n",
        "        inv_1 / B,\n",
        "        (kl + np.log((2 * np.sqrt(n_bound)) / delta)) / n_bound,\n",
        "    )\n",
        "    confidence_term = math.log(2.0 * math.sqrt(N) / delta) \n",
        "    rightside = ((kl+confidence_term) / N) \n",
        "    loss = B * risk\n",
        "\n",
        "    last_kl_inv[\"emploss\"] = emploss.mean().item()\n",
        "    last_kl_inv[\"rightside\"] = rightside\n",
        "    last_kl_inv[\"fullloss\"] = loss\n",
        "    last_kl_inv[\"leftside\"] = inv_1 / B\n",
        "    return last_kl_inv\n",
        "#####################################################################################\n",
        "def compute_risk_last_invkl( eval_loader,kl,B, delta_test, delta):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    N = len(eval_loader)\n",
        "    target = torch.cat([element[1] for element in eval_loader])\n",
        "    predict = torch.cat([element[2] for element in eval_loader])\n",
        "    emploss = torch.clamp(\n",
        "            (predict - target) ** 2, max = B ).mean()\n",
        "\n",
        "    loss = 0\n",
        "    n_bound = N\n",
        "  \n",
        "    last_kl_inv = {\n",
        "    \"emploss\": 0,\n",
        "    \"rightside\": 0,\n",
        "    \"fullloss\": 0,\n",
        "    \"leftside\": 0\n",
        "    }\n",
        "    risk = 0\n",
        "    \n",
        "    inv_1 = solve_kl_sup(emploss.mean().item(), np.log(1 / delta_test) / n_bound)\n",
        "    risk = solve_kl_sup(\n",
        "        inv_1 / B,\n",
        "        (kl + np.log((2 * np.sqrt(n_bound)) / delta)) / n_bound,\n",
        "    )\n",
        "    confidence_term = math.log(2.0 * math.sqrt(N) / delta) \n",
        "    rightside = ((kl+confidence_term) / N) \n",
        "    loss = B * risk\n",
        "\n",
        "    last_kl_inv[\"emploss\"] = emploss.mean().item()\n",
        "    last_kl_inv[\"rightside\"] = rightside\n",
        "    last_kl_inv[\"fullloss\"] = loss\n",
        "    last_kl_inv[\"leftside\"] = inv_1 / B\n",
        "    return last_kl_inv\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rYywNtri14tV"
      },
      "source": [
        "# Recursive PAC-Bayes Model training\n",
        "\n",
        "Consider optimization process of model, save the predicted data as specific nomber of spilits, and calculate bound values using recursion."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uG5PZt1Y13fT"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "class PerformanceEval:\n",
        "    def __init__(self, env, model, seed, eval_episodes, epochs=1, batch_size=32): \n",
        "        self.env = env\n",
        "        self.model = model\n",
        "        self.seed = seed\n",
        "        self.device = device\n",
        "        self.eval_episodes = eval_episodes\n",
        "        self.test_episodes = test_episodes\n",
        "        self.gamma =  0.99 \n",
        "        self.learning_rate =  2e-2  \n",
        "        self.epochs = epochs\n",
        "        self.batch_size = batch_size\n",
        "        self.num_repeats = 1\n",
        "        self.loss_history = []  # Store loss for plotting\n",
        "        self.posteriors = []  # List of posterior values\n",
        "        self.model_type = model_type\n",
        "        self.emp_risk_data = [[] for _ in range(self.eval_episodes)]\n",
        "        self.MCnet = None\n",
        "        self.minmax = []\n",
        "        self.kl_NR_I =0\n",
        "        self.local_reparametrization = True\n",
        "\n",
        "        # Load data\n",
        "        self.data = self.get_data()\n",
        "        self.values, self.test_tuples = self.value_tuples()\n",
        "        self.G = self.get_G(self.values, mode=\"train\")  # Get G values for training data\n",
        "\n",
        "        # Loss function & model initialization\n",
        "        self.loss_fn = torch.nn.MSELoss()\n",
        "        self.init_value_net()\n",
        "        self.init_value_net(\"priornet\")\n",
        "        self.priornet.load_state_dict(self.net.state_dict())\n",
        "\n",
        "        # Train model\n",
        "        self.train()\n",
        "\n",
        "        # Plot loss curve\n",
        "        self.plot_loss()\n",
        "\n",
        "        risklist = []\n",
        "        predictions_mv = []\n",
        "        prediction = []\n",
        "        for i in range(self.eval_episodes):\n",
        "            if len(self.emp_risk_data[i]) > 0:\n",
        "                risklist.append(self.emp_risk_data[i])\n",
        "        self.excess, self.kllist, self.E_tlist, self.B_tlist, self.laststep, self.max_B, self.B_T_inv, self.boundsize, self.posterioremploss = compute_risk_rpb(risklist)\n",
        "        self.risk_MC = 0 \n",
        "\n",
        "    def load_path(self, i):\n",
        "        \"\"\"Loads stored trajectories.\"\"\"\n",
        "        path_i = (\n",
        "            f\"Ant_validationround_100000/seed_0{self.seed}/performance_infos_{i}.pt\"\n",
        "        )\n",
        "        return torch.load(path_i, map_location=torch.device('cpu'), weights_only=False)\n",
        "\n",
        "    def get_data(self):\n",
        "        \"\"\"Load data for all episodes.\"\"\"\n",
        "        return [self.load_path(i) for i in range(self.test_episodes)]\n",
        "\n",
        "    def value_tuples(self):\n",
        "        \"\"\"Extract (state, reward) pairs from evaluation and test episodes.\"\"\"\n",
        "        values = [[] for _ in range(self.test_episodes)]\n",
        "\n",
        "        for i, episode in enumerate(self.data[: self.test_episodes]):\n",
        "            trajectory = episode[\"trajectory\"]\n",
        "            values[i].extend((step[2], step[5]) for step in trajectory)\n",
        "\n",
        "        # Split the values into evaluation and test sets\n",
        "        return values[: self.eval_episodes], values[100 :]\n",
        "\n",
        "    def get_G(self, values, mode=\"train\"):\n",
        "        \"\"\"Compute return G using discounted rewards.\"\"\"\n",
        "\n",
        "        G = []\n",
        "        all_returns = []  # List to store all returns for plotting\n",
        "        for episode in values:\n",
        "            episode_G = []\n",
        "            next_return = 0  # Initialize the return for the next step\n",
        "            for s, r in reversed(episode):  # Iterate in reverse order\n",
        "                next_return = r + self.gamma * next_return\n",
        "                episode_G.append((s, next_return))\n",
        "            episode_G.reverse()  # Reverse to restore original order\n",
        "\n",
        "            if any(x in env.lower() for x in [\"ant\", \"cheetah\"]):\n",
        "                sampled_episode_G = episode_G[::5]\n",
        "            elif any(x in env.lower() for x in [\"humanoid\", \"hopper\", \"walker2d\"]):\n",
        "                sampled_episode_G = episode_G[::3] \n",
        "            G.append(sampled_episode_G)\n",
        "\n",
        "            all_returns.extend([t[1] for t in sampled_episode_G])  # collect returns\n",
        "\n",
        "        return G\n",
        "\n",
        "\n",
        "    def init_value_net(self, net_name=\"net\"):\n",
        "        \"\"\"Initialize a new Value Network based on the provided network name and assign it as an attribute.\"\"\"\n",
        "        s, _ = self.G[0][0]  # Extract first state tensor\n",
        "\n",
        "        # Create the network based on the net_name and assign it as an attribute\n",
        "        if net_name == \"net\":\n",
        "            if self.model_type == \"fullvb\":\n",
        "                self.net = ValueNetVBfull((s.shape[0],), self.local_reparametrization).to(self.device)\n",
        "        elif net_name == \"priornet\":\n",
        "            if self.model_type == \"fullvb\":\n",
        "                self.priornet = ValueNetVBfull((s.shape[0],), self.local_reparametrization).to(self.device)\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown network name: {net_name}\")\n",
        "\n",
        "        if net_name == \"net\":\n",
        "            network = self.net\n",
        "            self.optim_net = optim.Adam(network.parameters(), lr=self.learning_rate)#, weight_decay=1e-4)\n",
        "            self.scheduler_net = optim.lr_scheduler.StepLR(\n",
        "            self.optim_net, step_size=10, gamma=0.5\n",
        "        )\n",
        "        elif net_name == \"priornet\":\n",
        "            network = self.priornet\n",
        "            self.optim_priornet = optim.Adam(network.parameters(), lr=self.learning_rate)#, weight_decay=1e-4)\n",
        "            self.scheduler_priornet = optim.lr_scheduler.StepLR(\n",
        "            self.optim_priornet, step_size=10, gamma=0.5\n",
        "        )\n",
        "  \n",
        "\n",
        "    def excess_loss(self, emploss, emploss_prior, gamma_t = 0.05):\n",
        "        \"\"\"Compute the excess loss.\"\"\"\n",
        "        return (emploss - gamma_t * emploss_prior) #.clamp(min=0)\n",
        "    \n",
        "    def E_t(self, excessloss, kl , n_bound , B , delta = 0.025, mu=0, gamma_t = 0.5):\n",
        "        \"\"\"Evaluate the E_t bound term from excess losses and a KL value.\n",
        "\n",
        "        The excess losses are split around `mu` into the mean positive part\n",
        "        (z_sup) and the mean negative part (z_inf). Each part is rescaled by its\n",
        "        range, (B - mu) and (mu + gamma_t * B) respectively, and combined with\n",
        "        sqrt((kl + log(4 * sqrt(n_bound) / delta)) / (2 * n_bound)).\n",
        "\n",
        "        Returns a tensor (0-dim when `kl` is 0-dim).\n",
        "        \"\"\"\n",
        "        floor = torch.tensor(0.0)\n",
        "        above_mu = excessloss - mu\n",
        "        below_mu = mu - excessloss\n",
        "        z_sup = torch.maximum(floor, above_mu).mean()\n",
        "        z_inf = torch.maximum(floor, below_mu).mean()\n",
        "\n",
        "        confidence_term = np.log((4 * np.sqrt(n_bound)) / delta)\n",
        "        kl_ratio = torch.div(kl + confidence_term, 2 * n_bound)\n",
        "\n",
        "        upper_range = B - mu\n",
        "        lower_range = mu + gamma_t * B\n",
        "        upper_part = upper_range * (z_sup / upper_range + torch.sqrt(kl_ratio))\n",
        "        lower_part = lower_range * (z_inf / lower_range - torch.sqrt(kl_ratio))\n",
        "        return mu + upper_part - lower_part\n",
        "    \n",
        "    def get_step_sizes(self, total_episodes):\n",
        "        \"\"\"Return increasing cumulative episode counts for progressive training.\n",
        "\n",
        "        The boundaries sit at 3%, 7%, 13%, 25%, 50% and 100% of\n",
        "        `total_episodes`. Every boundary is at least one episode past the\n",
        "        previous one and never exceeds `total_episodes`.\n",
        "\n",
        "        With fewer than six episodes the schedule is truncated as soon as\n",
        "        `total_episodes` is reached, so no boundary is repeated (a repeated\n",
        "        boundary would describe an empty portion). An empty list is returned\n",
        "        when `total_episodes` is zero or negative.\n",
        "        \"\"\"\n",
        "        ratios = (0.03, 0.07, 0.13, 0.25, 0.5, 1.0)\n",
        "        steps = []\n",
        "        last = 0\n",
        "        for ratio in ratios:\n",
        "            if last >= total_episodes:\n",
        "                break  # every episode is already covered\n",
        "            step = min(max(last + 1, int(total_episodes * ratio)), total_episodes)\n",
        "            steps.append(step)\n",
        "            last = step\n",
        "        return steps\n",
        "\n",
        "    def train(self):\n",
        "        \"\"\"Train the value network using progressive episode inclusion.\n",
        "\n",
        "        `self.G` is cut at the boundaries listed in `step_sizes`. For each portion,\n",
        "        `self.net` is optimized on the `E_t` objective, which is built from the\n",
        "        clamped squared losses of `self.net` and `self.priornet` and the KL value\n",
        "        returned by `calculate_kl_terms_informative`. After each portion,\n",
        "        `risk_input` is called and `self.priornet` is overwritten with the weights\n",
        "        of `self.net`.\n",
        "\n",
        "        Sets `self.emploss_prior`, `self.emploss_posterior`, `self.B_list`,\n",
        "        `self.upper_limit`, `self.tariningloss`, `self.priornet_uninform` and\n",
        "        `self.kl_NR_I`, and appends to `self.loss_history`. Reads the module-level\n",
        "        name `env_limit`.\n",
        "        \"\"\"\n",
        "        total_episodes = len(self.G)\n",
        "        half = int(len(self.G)/2)\n",
        "        #step_sizes = self.get_step_sizes(total_episodes)  # 6 portions \n",
        "        step_sizes = [half, total_episodes]  # 2 portion\n",
        "        previous_step = 0\n",
        "\n",
        "        # Precompute episode data for efficiency\n",
        "        episode_data = [\n",
        "            (\n",
        "                torch.stack([s for s, _ in self.G[i]]),\n",
        "                torch.tensor([t for _, t in self.G[i]], dtype=torch.float32),\n",
        "            )\n",
        "            for i in range(total_episodes)\n",
        "        ]\n",
        "       \n",
        "        # Initialize loss tracking\n",
        "        self.emploss_prior = [[] for _ in range(len(step_sizes))]\n",
        "        self.emploss_posterior = [[] for _ in range(len(step_sizes))]\n",
        "\n",
        "        \n",
        "        self.priornet.eval()  # Ensure priornet is in eval mode\n",
        "        # env_limit is used both as the clamp on squared losses and as B in E_t\n",
        "        B = env_limit\n",
        "        self.B_list = []\n",
        "        self.upper_limit = B\n",
        "        self.tariningloss = []\n",
        "        for i, step in enumerate(step_sizes):\n",
        "            \n",
        "            # Episodes trained on in this portion: [previous_step, step)\n",
        "            episode_indices = range(previous_step, min(step, total_episodes))\n",
        "\n",
        "            # Sample count used as n_bound in E_t. It covers previous_step through\n",
        "            # the LAST episode, not only the episodes of this portion.\n",
        "            len_n_bound = 0\n",
        "            for j in range(previous_step, total_episodes):\n",
        "                len_n_bound += len(self.G[j])\n",
        "\n",
        "            # Concatenate states and targets for selected episodes\n",
        "            all_states = torch.cat([episode_data[i][0] for i in episode_indices])\n",
        "            all_targets = torch.cat([episode_data[i][1] for i in episode_indices])\n",
        "\n",
        "            # Create DataLoader\n",
        "            dataset = TensorDataset(all_states, all_targets)\n",
        "            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)\n",
        "            # Running maximum of the clamped posterior loss; updated below but not read again in this method\n",
        "            max_posterior_loss = torch.tensor(float('-inf'), device=self.device)  # Initialize with a very small value\n",
        "\n",
        "            # Training loop\n",
        "            for epoch in range(self.epochs):\n",
        "                epoch_loss = 0.0\n",
        "                \n",
        "                for states, targets in dataloader:\n",
        "                    states, targets = states.to(self.device), targets.to(self.device)\n",
        "                    ## Forward pass and loss computation\n",
        "                    posterior_predictions = self.net(states)[0]\n",
        "                    \n",
        "                    # Per-sample squared error, clamped at upper_limit\n",
        "                    posteriorloss = (posterior_predictions - targets.view(-1, 1)) ** 2\n",
        "                    posteriorloss = torch.clamp(posteriorloss, max=self.upper_limit)\n",
        "                    max_posterior_loss = torch.max(max_posterior_loss, posteriorloss.max())\n",
        "\n",
        "                    # The prior network is evaluated without gradients\n",
        "                    with torch.no_grad():\n",
        "                        prior_prediction = self.priornet(states)[0]\n",
        "                        priorloss = (prior_prediction - targets.view(-1, 1)) ** 2\n",
        "                        priorloss = torch.clamp(priorloss, max=self.upper_limit)\n",
        "                    kl = calculate_kl_terms_informative(self.net, self.priornet)[0]\n",
        "\n",
        "                    # Calculate final loss using excess_loss function\n",
        "                    # NOTE(review): excess_loss runs with its default gamma_t=0.05 here, while E_t\n",
        "                    # defaults to gamma_t=0.5 and risk_input passes 0.5 - confirm this is intended.\n",
        "                    excess = self.excess_loss(posteriorloss, priorloss)\n",
        "                    #excess = torch.clamp(excess, max=self.upper_limit)\n",
        "                    loss = self.E_t(excess, kl, len_n_bound, self.upper_limit)  #len(all_states), change last input to B  \n",
        "\n",
        "                    \n",
        "                    ## Track loss (convert to scalar values)\n",
        "                    self.emploss_prior[i].append(priorloss.mean().detach().cpu().item())\n",
        "                    self.emploss_posterior[i].append(posteriorloss.mean().detach().cpu().item())\n",
        "                    ## Backward pass and optimization\n",
        "                    self.optim_net.zero_grad()\n",
        "                    loss.backward()\n",
        "\n",
        "                    # Clip gradients for stability\n",
        "                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=1)\n",
        "\n",
        "                    # Update the network\n",
        "                    self.optim_net.step()\n",
        "                    \n",
        "                    # Track epoch loss\n",
        "                    epoch_loss += loss.item()\n",
        "\n",
        "                # Log average loss for the epoch\n",
        "                avg_epoch_loss = epoch_loss / len(dataloader)\n",
        "                print(\n",
        "                    f\"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_epoch_loss:.4f}, Episodes Considered: {len(episode_indices)}\"\n",
        "                )\n",
        "                \n",
        "\n",
        "                self.loss_history.append(avg_epoch_loss)\n",
        "                # One record per epoch, taken from the LAST batch of that epoch:\n",
        "                # (mean positive excess, mean negative excess, kl, loss, mean excess).\n",
        "                # NOTE(review): kl and loss are stored without detach().\n",
        "                self.tariningloss.append(((torch.maximum(torch.tensor(0.0), excess )).mean(),(torch.maximum(torch.tensor(0.0), 0 - excess)).mean(), kl, loss, excess.mean()))\n",
        "                self.scheduler_net.step()\n",
        "\n",
        "            # Call risk calculation after training on the current portion of episodes\n",
        "            self.risk_input(previous_step, self.upper_limit, step)\n",
        "            self.B_list.append(self.upper_limit)\n",
        "\n",
        "            # Keep a copy of the prior as it was during the first portion\n",
        "            # (taken before priornet is overwritten at the end of this iteration)\n",
        "            if i==0:\n",
        "                self.priornet_uninform = copy.deepcopy(self.priornet)\n",
        "\n",
        "            # get informative kl for NR-I\n",
        "            if step == total_episodes:\n",
        "                self.kl_NR_I = calculate_kl_terms_informative(self.net, self.priornet)[0]\n",
        "\n",
        "            # Update priornet with the current trained net\n",
        "            previous_step = step\n",
        "            self.priornet.load_state_dict(self.net.state_dict())\n",
        "        \n",
        "       \n",
        "\n",
        "\n",
        "\n",
        "    def risk_input(self, prev_step, upper_limit, step):\n",
        "        \"\"\"Collect the per-sample records used later for the risk computation.\n",
        "\n",
        "        Every (state, target) pair of the episodes `self.G[prev_step:]` is\n",
        "        evaluated with `self.net` and `self.priornet` under `torch.no_grad()`.\n",
        "        One tuple per sample is appended to `self.emp_risk_data[prev_step]`:\n",
        "        (state, target, prediction, excess_loss, upper_limit, kl, mean_var,\n",
        "        clamped posterior loss). The running max/min of the unclamped posterior\n",
        "        and prior squared losses are appended to `self.minmax` as\n",
        "        (max_post, min_post, max_prior, min_prior).\n",
        "\n",
        "        `step` is kept in the signature for existing callers but is not used.\n",
        "        \"\"\"\n",
        "        kl = calculate_kl_terms_informative(self.net, self.priornet)[0]\n",
        "\n",
        "        max_postloss = torch.tensor(float('-inf'), device=self.device)\n",
        "        min_postloss = torch.tensor(float('inf'), device=self.device)\n",
        "        max_priorloss = torch.tensor(float('-inf'), device=self.device)\n",
        "        min_priorloss = torch.tensor(float('inf'), device=self.device)\n",
        "\n",
        "        for episode in self.G[prev_step:]:\n",
        "            for state, target in episode:\n",
        "                target = torch.tensor(target, device=self.device).unsqueeze(0)\n",
        "                state_on_device = state.to(self.device)  # moved once, used by both networks\n",
        "                with torch.no_grad():\n",
        "                    prediction, mean_var = self.net(state_on_device)\n",
        "                    prior_prediction = self.priornet(state_on_device)[0]\n",
        "\n",
        "                    # Posterior squared error: track extremes before clamping\n",
        "                    emploss = (prediction - target.view(-1, 1)) ** 2\n",
        "                    max_postloss = torch.max(max_postloss, emploss)\n",
        "                    min_postloss = torch.min(min_postloss, emploss)\n",
        "                    emploss_ = torch.clamp(emploss, max=upper_limit)\n",
        "\n",
        "                    # Prior squared error: same treatment\n",
        "                    priorloss = (prior_prediction - target.view(-1, 1)) ** 2\n",
        "                    max_priorloss = torch.max(max_priorloss, priorloss)\n",
        "                    min_priorloss = torch.min(min_priorloss, priorloss)\n",
        "                    priorloss_ = torch.clamp(priorloss, max=upper_limit)\n",
        "\n",
        "                    excess = self.excess_loss(emploss_, priorloss_, 0.5)\n",
        "\n",
        "                # The stored state is the original tensor, not the device copy\n",
        "                self.emp_risk_data[prev_step].append(\n",
        "                    (state, target, prediction, excess, upper_limit, kl, mean_var, emploss_)\n",
        "                )\n",
        "        self.minmax.append((max_postloss, min_postloss, max_priorloss, min_priorloss))\n",
        "\n",
        "\n",
        "    def plot_loss(self):\n",
        "        \"\"\"Plot the training diagnostics and save them as a PNG.\n",
        "\n",
        "        The top two rows show the epoch loss history with the B values, excess\n",
        "        vs. total loss, positive vs. negative excess, and the KL value. The\n",
        "        remaining rows show, for each training portion, the posterior and prior\n",
        "        empirical losses on twin axes. The number of portions is taken from the\n",
        "        tracked loss lists (capped by the number of free axes).\n",
        "        \"\"\"\n",
        "        fig, axes = plt.subplots(5, 2, figsize=(12, 14))\n",
        "\n",
        "        def _series(position):\n",
        "            # Extract one component of every tariningloss tuple as Python floats.\n",
        "            return [entry[position].detach().cpu().item() for entry in self.tariningloss]\n",
        "\n",
        "        excess_pos = _series(0)  # mean positive part of the excess loss\n",
        "        excess_neg = _series(1)  # mean negative part of the excess loss\n",
        "        kl_values = _series(2)   # KL value\n",
        "        loss = _series(3)        # total loss\n",
        "        excess = _series(4)      # mean excess loss\n",
        "\n",
        "        # --- B Values & Loss History ---\n",
        "        ax0 = axes[0, 0]\n",
        "        ax0.plot(self.loss_history, label=\"Loss History\", color=\"red\")\n",
        "        ax0.plot(self.B_list[:-1], label=\"B Values\", color=\"green\", linestyle=\"dashed\")\n",
        "        ax0.set_ylabel(\"Loss History & B Values\")\n",
        "        ax0.set_xlabel(\"Episode Count\")\n",
        "        ax0.legend(loc=\"upper left\")\n",
        "\n",
        "        # --- Excess Loss & Total Loss (Twin Axis) ---\n",
        "        ax1 = axes[0, 1]\n",
        "        ax1.plot(excess, label=\"Excess Loss\", color=\"blue\", linestyle=\"dotted\")\n",
        "        ax1.set_ylabel(\"Excess Loss\", color=\"blue\")\n",
        "        ax1.set_xlabel(\"Gradient Update\")\n",
        "        ax1.legend(loc=\"upper left\")\n",
        "\n",
        "        ax1_twin = ax1.twinx()  # Twin axis for Total Loss\n",
        "        ax1_twin.plot(loss, label=\"Total Loss\", color=\"red\", linestyle=\"solid\")\n",
        "        ax1_twin.set_ylabel(\"Total Loss\", color=\"red\")\n",
        "        ax1_twin.legend(loc=\"upper right\")\n",
        "\n",
        "        # --- Excess Positive & Negative (Together) ---\n",
        "        ax2 = axes[1, 0]\n",
        "        ax2.plot(excess_pos, label=\"Positive Excess Loss\", color=\"blue\", linestyle=\"dashed\")\n",
        "        ax2.set_ylabel(\"Excess Positive\")\n",
        "        ax2.set_xlabel(\"Gradient Update\")\n",
        "        ax2.legend(loc=\"upper left\")\n",
        "\n",
        "        ax2_twin = ax2.twinx()  # Twin axis for the negative excess loss\n",
        "        ax2_twin.plot(excess_neg, label=\"Negative Excess Loss\", color=\"green\", linestyle=\"dotted\")\n",
        "        ax2_twin.set_ylabel(\"Excess Negative\", color=\"green\")\n",
        "        ax2_twin.legend(loc=\"upper right\")\n",
        "\n",
        "        # --- KL Loss (Separate Plot) ---\n",
        "        ax3 = axes[1, 1]\n",
        "        ax3.plot(kl_values, label=\"KL Loss\", color=\"purple\", linestyle=\"solid\")\n",
        "        ax3.set_ylabel(\"KL Loss\", color=\"purple\")\n",
        "        ax3.set_xlabel(\"Gradient Update\")\n",
        "        ax3.legend(loc=\"upper left\")\n",
        "\n",
        "        # --- Per-portion posterior/prior losses ---\n",
        "        # Rows 0-1 are taken by the plots above; every remaining axis can host one portion.\n",
        "        first_portion_row = 2\n",
        "        n_rows, n_cols = axes.shape\n",
        "        free_axes = (n_rows - first_portion_row) * n_cols\n",
        "        n_portions = min(len(self.emploss_posterior), len(self.emploss_prior), free_axes)\n",
        "        for portion in range(n_portions):\n",
        "            ax = axes[first_portion_row + portion // n_cols, portion % n_cols]\n",
        "            ax.plot(self.emploss_posterior[portion], label=\"posterior loss\", color=\"blue\")\n",
        "            ax.set_ylabel(f\"posterior loss (portion {portion + 1})\", color=\"blue\")\n",
        "            ax.set_xlabel(f\"portion {portion + 1}\")\n",
        "            ax.legend(loc=\"upper left\")\n",
        "\n",
        "            prior_ax = ax.twinx()\n",
        "            prior_ax.plot(self.emploss_prior[portion], label=\"prior loss\", color=\"green\")\n",
        "            prior_ax.set_ylabel(f\"prior loss (portion {portion + 1})\", color=\"green\")\n",
        "            prior_ax.legend(loc=\"upper right\")\n",
        "\n",
        "        # Lay the figure out once, after all axes are populated\n",
        "        plt.tight_layout()\n",
        "\n",
        "        # NOTE(review): the seed is prefixed with a literal \"0\" here; for seeds >= 10 this\n",
        "        # differs from a zero-padded two-digit seed directory - confirm the directory naming.\n",
        "        plt.savefig(\n",
        "            f\"Ant_validationround_100000/seed_0{self.seed}/Zplotslogs.png\"\n",
        "        )\n",
        "        plt.close(fig)\n",
        "\n",
        "\n",
        "    def predict(self):\n",
        "        \"\"\"Score the trained net on the test and training sets, one state at a time.\n",
        "\n",
        "        Returns:\n",
        "            (loss, loss_train): mean of the squared errors, each clamped at the\n",
        "            module-level `env_limit`, for the test episodes and for the training\n",
        "            episodes in `self.G`.\n",
        "        \"\"\"\n",
        "        # Returns for the held-out episodes\n",
        "        self.G_test = self.get_G(self.test_tuples, mode=\"test\")\n",
        "\n",
        "        def _flatten(episodes):\n",
        "            # Stack all states and gather all targets of `episodes` on self.device.\n",
        "            pairs = [pair for episode in episodes for pair in episode]\n",
        "            states = torch.stack([s for s, _ in pairs]).to(self.device)\n",
        "            targets = torch.tensor([t for _, t in pairs], dtype=torch.float32).to(self.device)\n",
        "            return states, targets\n",
        "\n",
        "        test_states, test_targets = _flatten(self.G_test)\n",
        "        train_states, train_targets = _flatten(self.G)\n",
        "        clamp_max = env_limit\n",
        "\n",
        "        # Forward passes run per state (no batching): test states first, then training states\n",
        "        test_outputs, train_outputs = [], []\n",
        "        with torch.no_grad():\n",
        "            for state in test_states:\n",
        "                output, _ = self.net(state)\n",
        "                test_outputs.append(output)\n",
        "            for state in train_states:\n",
        "                output, _ = self.net(state)\n",
        "                train_outputs.append(output)\n",
        "\n",
        "        def _clamped_mse(outputs, targets):\n",
        "            # Mean of the per-sample squared errors after clamping at clamp_max.\n",
        "            squared_errors = (torch.stack(outputs).flatten() - targets) ** 2\n",
        "            return torch.clamp(squared_errors, max=clamp_max).mean()\n",
        "\n",
        "        loss = _clamped_mse(test_outputs, test_targets)\n",
        "        loss_train = _clamped_mse(train_outputs, train_targets)\n",
        "        return loss, loss_train\n",
        "    \n",
        "    \n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Run configuration. parse_known_args() is used so that arguments this\n",
        "    # parser does not declare are ignored instead of raising.\n",
        "    parser = argparse.ArgumentParser()\n",
        "\n",
        "    parser.add_argument(\"--env\", type=str, default=\"ant\")\n",
        "    parser.add_argument(\"--model\", type=str, default=\"validationround_100000\")\n",
        "    parser.add_argument(\"--seed\", type=int, default=1)\n",
        "    parser.add_argument(\"--eval_episodes\", type=int, default=100)\n",
        "    parser.add_argument(\"--test_episodes\", type=int, default=200)\n",
        "    args, unknown = parser.parse_known_args()\n",
        "\n",
        "    env = args.env\n",
        "    model = args.model\n",
        "    seed = args.seed\n",
        "    eval_episodes = args.eval_episodes\n",
        "    test_episodes = args.test_episodes\n",
        "\n",
        "    # env_limit is the clamp applied to squared losses (and the B passed to the bound).\n",
        "    # Fail loudly for environments without a configured limit, instead of hitting a\n",
        "    # NameError later or silently reusing a value left over in the kernel.\n",
        "    if env == \"ant\":\n",
        "        env_limit = 500000\n",
        "    else:\n",
        "        raise ValueError(f\"No env_limit is configured for env '{env}'\")\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    path = f\"Ant_validationround_100000/seed_{str(seed).zfill(2)}\"\n",
        "    model_type = \"fullvb\"  # 'finalvb' or 'fullvb'\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Run Recursive bound \n",
        "Setting \"num_rows\" to either 2 or 6 returns the bound calculation for the corresponding depth. \n",
        "\n",
        "A detailed log file containing bound values and various variables used in the calculation will be saved in the specified \"path\" directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 399
        },
        "id": "BRFN2y0505pX",
        "outputId": "3d0fc500-6be0-4910-b140-fe099cde25ed"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/1, Loss: 95161.1364, Episodes Considered: 50\n",
            "Epoch 1/1, Loss: 35925.4280, Episodes Considered: 50\n",
            "KL Divergence values per step: [117.61013793945312, 5.274386405944824]\n",
            "Current step: 1 len of bound: 16592\n",
            "Current step: 2 len of bound: 8305\n",
            "loss_excess avg is equal to (avg posteriorloss - gamma * avg priorloss)\n",
            "loss_excess avg is equal to (exces sup - exces inf) up to 3 decimal places\n",
            "B_ts_list_all steps: [40721.83655092389, 26841.16033271756], Recursive pac-bayes bound: 26841.16033271756\n",
            "excess loss: [0, 1726.773772196417]\n"
          ]
        }
      ],
      "source": [
        "# NOTE(review): `model` is rebound from the model-name string to the evaluator object,\n",
        "# so re-running this cell passes the object where the name is expected.\n",
        "model = PerformanceEval(env, model, seed, eval_episodes)\n",
        "loss_test,loss_train = model.predict()\n",
        "\n",
        "# Summary values written at the top of the log file.\n",
        "# NOTE(review): this name shadows `log` imported from math in the first cell.\n",
        "log = {\n",
        "    \"number of eval episodes\": eval_episodes,\n",
        "    \"KLLIST\": model.kllist,\n",
        "    \"excesslist\": model.excess,\n",
        "    \"B_tlist_bound\": model.B_tlist,\n",
        "    \"laststep_invbound\": model.B_T_inv,\n",
        "    \"average squared loss (ground truth)\": loss_test,\n",
        "    \"average squared loss (train data)\": loss_train,\n",
        "}\n",
        "\n",
        "columns = [\n",
        "    \"t\", \"n_bound\", \"excess\", \"kl\", \"B_t\", \"posterioremploss\",\n",
        "    \"E_t\", \"inv_b\" , \"Epsilon_plas\", \"invinf_gammab\" , \"Epsilon_neg\", \"kl_n\",\n",
        "    \"max_post_loss\", \"min_post_loss\", \"max_priorloss\", \"min_priorloss\",\n",
        "]\n",
        "num_rows = 2  # Number of table rows (2 portions); use 6 for the 6-portion schedule\n",
        "\n",
        "# One list of unpadded cell strings per table row, in the order of `columns`.\n",
        "table_cells = [\n",
        "    [\n",
        "        str(i + 1),  # Row number\n",
        "        f\"{model.boundsize[i]:.5f}\",\n",
        "        f\"{model.excess[i]:.5f}\",\n",
        "        f\"{model.kllist[i]:.5f}\",\n",
        "        f\"{model.B_tlist[i]:.5f}\",\n",
        "        f\"{model.posterioremploss[i]:.5f}\",\n",
        "        *(f\"{model.E_tlist[i][k]:.5f}\" for k in range(6)),\n",
        "        *(f\"{model.minmax[i][k].item():.5f}\" for k in range(4)),\n",
        "    ]\n",
        "    for i in range(num_rows)\n",
        "]\n",
        "\n",
        "# Column width = widest of header and cells, plus 2 characters of padding.\n",
        "col_widths = [\n",
        "    max(len(columns[j]), max(len(cells[j]) for cells in table_cells)) + 2\n",
        "    for j in range(len(columns))\n",
        "]\n",
        "\n",
        "# Separator and header lines that align exactly with the column widths\n",
        "separator = \"+-\" + \"-+-\".join(\"-\" * w for w in col_widths) + \"-+\"\n",
        "header = \"| \" + \" | \".join(name.ljust(width) for name, width in zip(columns, col_widths)) + \" |\\n\"\n",
        "\n",
        "# Writing to log file\n",
        "with open(f\"{path}/{model_type}_{num_rows}portion_.log\", \"w\") as f:\n",
        "    now = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
        "    f.write(f\"Log created on: {now}\\n\\n\")\n",
        "    f.write(f\"epoch: { model.epochs}\\n\")\n",
        "    f.write(f\"batch_size: {model.batch_size}\\n\\n\")\n",
        "    f.write(f\"local reparametrization used: {model.local_reparametrization}\\n\")\n",
        "\n",
        "    # Metadata\n",
        "    for key, value in log.items():\n",
        "        f.write(f\"{key}: {value}\\n\")\n",
        "\n",
        "    # Table header\n",
        "    f.write(separator + \"\\n\")\n",
        "    f.write(header)\n",
        "    f.write(separator + \"\\n\")\n",
        "\n",
        "    # Table rows, right-aligned, each followed by a separator line\n",
        "    for cells in table_cells:\n",
        "        row = [cell.rjust(width) for cell, width in zip(cells, col_widths)]\n",
        "        f.write(\"| \" + \" | \".join(row) + \" |\\n\")\n",
        "        f.write(separator + \"\\n\")\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Model Training and Run Non-Recursive cases\n",
        "Setting \"loss_type\" to \"noninformative\" or \"informative\" calculates the NR-NR or NR-I bound, respectively.\n",
        "This file contains the model training (optimization) process for the non-recursive bounds, along with the final bound calculation for each of the two cases.\n",
        "\n",
        "A detailed log will be saved in the specified \"path\" directory, containing bound values and other variables used in the calculation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/1, Loss: 96772.6787, Episodes: 50\n",
            "Epoch 1/1, Loss: 27723.3220, Episodes: 50\n"
          ]
        }
      ],
      "source": [
        "class PerformanceEval:\n",
        "    def __init__(self, env, model, seed, eval_episodes, fold_index=0, epochs=1, batch_size=32):\n",
        "        \"\"\"Load the episodes, build both value networks, then train and plot.\n",
        "\n",
        "        Besides its arguments, this reads the module-level names\n",
        "        `test_episodes`, `device`, `loss_type` and `model_type`.\n",
        "        \"\"\"\n",
        "        # Run identification\n",
        "        self.env, self.model, self.seed = env, model, seed\n",
        "        self.fold_index = fold_index\n",
        "        self.model_type = model_type\n",
        "        self.loss_type = loss_type\n",
        "\n",
        "        # Data split sizes and compute device\n",
        "        self.eval_episodes = eval_episodes\n",
        "        self.test_episodes = test_episodes\n",
        "        self.device = device\n",
        "\n",
        "        # Optimization hyper-parameters\n",
        "        self.gamma = 0.99\n",
        "        self.learning_rate = 2e-2\n",
        "        self.epochs = epochs\n",
        "        self.batch_size = batch_size\n",
        "        self.local_reparametrization = False\n",
        "\n",
        "        # Result containers filled during training / bound computation\n",
        "        self.emp_risk_data_informative = [[] for _ in range(eval_episodes)]\n",
        "        self.posteriors = []\n",
        "        self.loss_history = []\n",
        "        self.minmax = []\n",
        "        self.kl_final = None\n",
        "\n",
        "        # Episodes -> (state, reward) tuples -> returns for the training split\n",
        "        self.data = self.get_data()\n",
        "        self.values, self.test_tuples = self.value_tuples()\n",
        "        self.G = self.get_G(self.values, mode=\"train\")\n",
        "\n",
        "        # Posterior and prior networks; the prior starts as a copy of the posterior's weights\n",
        "        self.loss_fn = torch.nn.MSELoss()\n",
        "        self.init_value_net()\n",
        "        self.init_value_net(\"priornet\")\n",
        "        self.priornet.load_state_dict(self.net.state_dict())\n",
        "\n",
        "        self.train()\n",
        "        self.plot_loss()\n",
        "\n",
        "    def calculate_bound(self):\n",
        "        \"\"\"Run `compute_inf_risk_invkl` on the first non-empty risk record list.\n",
        "\n",
        "        Stores the full result in `self.NR_results`, its \"fullloss\" entry in\n",
        "        `self.NR_results_loss`, and the number of records in `self._NR_n_bound`.\n",
        "        \"\"\"\n",
        "        non_empty = [records for records in self.emp_risk_data_informative if records]\n",
        "        first_records = non_empty[0]\n",
        "        self.NR_results = compute_inf_risk_invkl(first_records)\n",
        "        self.NR_results_loss = self.NR_results[\"fullloss\"]\n",
        "        self._NR_n_bound = len(first_records)\n",
        "\n",
        "    def load_path(self, i):\n",
        "        \"\"\"Load the saved `performance_infos_{i}.pt` record of this seed onto the CPU.\"\"\"\n",
        "        # NOTE(review): the seed is prefixed with a literal \"0\"; for seeds >= 10 this differs\n",
        "        # from a zero-padded two-digit seed directory - confirm the directory naming.\n",
        "        episode_file = f\"Ant_validationround_100000/seed_0{self.seed}/performance_infos_{i}.pt\"\n",
        "        cpu = torch.device('cpu')\n",
        "        # NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted files.\n",
        "        return torch.load(episode_file, map_location=cpu, weights_only=False)\n",
        "\n",
        "    def get_data(self):\n",
        "        \"\"\"Load the records of episodes 0 .. test_episodes - 1 via `load_path`, in order.\"\"\"\n",
        "        records = []\n",
        "        for episode_idx in range(self.test_episodes):\n",
        "            records.append(self.load_path(episode_idx))\n",
        "        return records\n",
        "\n",
        "    def value_tuples(self):\n",
        "        \"\"\"Turn each loaded episode into a list of (step[2], step[5]) pairs.\n",
        "\n",
        "        Returns:\n",
        "            (train, test): the first `eval_episodes` episodes, and the episodes\n",
        "            from index 100 onward (fixed index, independent of `eval_episodes`).\n",
        "        \"\"\"\n",
        "        values = []\n",
        "        for episode in self.data[:self.test_episodes]:\n",
        "            values.append([(step[2], step[5]) for step in episode[\"trajectory\"]])\n",
        "        # Pad with empty episodes so the list always has test_episodes entries\n",
        "        values.extend([] for _ in range(self.test_episodes - len(values)))\n",
        "\n",
        "        train_part = values[:self.eval_episodes]\n",
        "        test_part = values[100:]\n",
        "        return train_part, test_part\n",
        "\n",
        "    def get_G(self, values, mode=\"train\"):\n",
        "        \"\"\"Compute discounted returns per episode and subsample the time steps.\n",
        "\n",
        "        For every episode (a sequence of (state, reward) pairs) the return\n",
        "        G_t = r_t + self.gamma * G_{t+1} is accumulated backwards, and every\n",
        "        5th step (\"ant\", \"cheetah\") or every 3rd step (\"humanoid\", \"hopper\",\n",
        "        \"walker2d\") is kept, starting with the first one.\n",
        "\n",
        "        The environment name is taken from `self.env`, so the method does not\n",
        "        depend on a module-level variable. `mode` is accepted for existing\n",
        "        callers but does not change the result.\n",
        "\n",
        "        Returns:\n",
        "            A list with one list of (state, return) pairs per episode.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if `self.env` matches none of the known environments.\n",
        "        \"\"\"\n",
        "        env_name = self.env.lower()\n",
        "        if any(tag in env_name for tag in (\"ant\", \"cheetah\")):\n",
        "            stride = 5\n",
        "        elif any(tag in env_name for tag in (\"humanoid\", \"hopper\", \"walker2d\")):\n",
        "            stride = 3\n",
        "        else:\n",
        "            raise ValueError(f\"No subsampling stride is defined for environment: {self.env}\")\n",
        "\n",
        "        G = []\n",
        "        for episode in values:\n",
        "            episode_G = []\n",
        "            next_return = 0\n",
        "            # Walk the episode backwards so each return can reuse the following one\n",
        "            for s, r in reversed(episode):\n",
        "                next_return = r + self.gamma * next_return\n",
        "                episode_G.append((s, next_return))\n",
        "            episode_G.reverse()\n",
        "            G.append(episode_G[::stride])\n",
        "        return G\n",
        "\n",
        "    def init_value_net(self, net_name=\"net\"):\n",
        "        \"\"\"Create a value network with its Adam optimizer and StepLR scheduler.\n",
        "\n",
        "        The input size is taken from the first state in `self.G`.\n",
        "\n",
        "        Args:\n",
        "            net_name: \"net\" sets `self.net`, `self.optim_net` and\n",
        "                `self.scheduler_net`; \"priornet\" sets `self.priornet`,\n",
        "                `self.optim_priornet` and `self.scheduler_priornet`.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if `net_name` is neither \"net\" nor \"priornet\".\n",
        "        \"\"\"\n",
        "        # Validate first so a typo fails loudly instead of silently creating nothing\n",
        "        if net_name not in (\"net\", \"priornet\"):\n",
        "            raise ValueError(f\"Unknown network name: {net_name}\")\n",
        "\n",
        "        s, _ = self.G[0][0]\n",
        "        net_instance = ValueNetVBfull((s.shape[0],), self.local_reparametrization).to(self.device)\n",
        "        optimizer = optim.Adam(net_instance.parameters(), lr=self.learning_rate)\n",
        "        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)\n",
        "\n",
        "        if net_name == \"net\":\n",
        "            self.net = net_instance\n",
        "            self.optim_net = optimizer\n",
        "            self.scheduler_net = scheduler\n",
        "        else:\n",
        "            self.priornet = net_instance\n",
        "            self.optim_priornet = optimizer\n",
        "            self.scheduler_priornet = scheduler\n",
        "\n",
        "    def non_recursive_loss(self, loss, kl, n_bound, B, delta=0.025):\n",
        "        \"\"\"Return `loss` plus the KL penalty of the non-recursive objective.\n",
        "\n",
        "        The penalty is B * sqrt((kl + log(2 * sqrt(n_bound) / delta)) / (2 * n_bound)).\n",
        "        \"\"\"\n",
        "        confidence_term = np.log((2 * np.sqrt(n_bound)) / delta)\n",
        "        kl_ratio = (kl + confidence_term) / (2 * n_bound)\n",
        "        penalty = B * torch.sqrt(kl_ratio)\n",
        "        return loss + penalty\n",
        "\n",
        "    def train(self):\n",
        "        \"\"\"Train ``self.net`` on the episodes in ``self.G`` using ``non_recursive_loss``.\n",
        "\n",
        "        With ``loss_type == \"noninformative\"`` there is one stage covering all\n",
        "        episodes. Otherwise there are two stages: episodes ``[0, half)`` and then\n",
        "        ``[half, total)``. When ``loss_type == \"informative\"``, the weights of\n",
        "        ``self.net`` are copied into ``self.priornet`` after the first stage.\n",
        "\n",
        "        Side effects: resets ``self.emploss_prior``, ``self.emploss_posterior`` and\n",
        "        ``self.kl_train_list``, sets ``self.upper_limit``, updates ``self.net``\n",
        "        through ``self.optim_net`` / ``self.scheduler_net``, and finishes by calling\n",
        "        ``self.risk_input``.\n",
        "\n",
        "        NOTE(review): ``env_limit`` is neither a parameter nor an attribute; it is\n",
        "        read as a global and must be defined before this method runs.\n",
        "        \"\"\"\n",
        "        total_episodes = len(self.G)\n",
        "        half = int(len(self.G)/2)\n",
        "        # Each entry of step_sizes is the exclusive end index of one training stage.\n",
        "        if self.loss_type == \"noninformative\":\n",
        "            step_sizes = [ total_episodes]\n",
        "        else:\n",
        "            step_sizes = [half, total_episodes]\n",
        "        previous_step = 0\n",
        "        # One list of per-minibatch mean losses per stage (used by plot_loss).\n",
        "        self.emploss_prior = [[] for _ in range(len(step_sizes))]\n",
        "        self.emploss_posterior = [[] for _ in range(len(step_sizes))]\n",
        "        # Per-episode tensors: stacked states and float32 targets.\n",
        "        episode_data = [\n",
        "            (\n",
        "                torch.stack([s for s, _ in self.G[i]]),\n",
        "                torch.tensor([t for _, t in self.G[i]], dtype=torch.float32),\n",
        "            )\n",
        "            for i in range(total_episodes)\n",
        "        ]\n",
        "\n",
        "        self.kl_train_list = []\n",
        "        # Clamp value for the losses; also passed as the multiplier B of the bound.\n",
        "        self.upper_limit = env_limit\n",
        "\n",
        "        for i, step in enumerate(step_sizes):\n",
        "            # This stage trains on episodes [previous_step, step).\n",
        "            episode_indices = range(previous_step, min(step, total_episodes))\n",
        "            # Total number of (state, target) pairs in the stage's episodes.\n",
        "            n_bound = sum(len(self.G[j]) for j in range(previous_step, step))\n",
        "\n",
        "            # The comprehensions' `i` is local to them and does not rebind the stage index `i`.\n",
        "            all_states = torch.cat([episode_data[i][0] for i in episode_indices])\n",
        "            all_targets = torch.cat([episode_data[i][1] for i in episode_indices])\n",
        "            dataloader = DataLoader(TensorDataset(all_states, all_targets), batch_size=self.batch_size, shuffle=True)\n",
        "            \n",
        "\n",
        "            for epoch in range(self.epochs):\n",
        "                epoch_loss = 0.0\n",
        "                for states, targets in dataloader:\n",
        "                    states, targets = states.to(self.device), targets.to(self.device)\n",
        "                    # NOTE(review): the forward passes may draw random samples, so the\n",
        "                    # order of the statements below is presumably significant - confirm.\n",
        "                    posterior_predictions = self.net(states)[0]\n",
        "                    # Per-sample clamped squared error; only recorded, never back-propagated.\n",
        "                    posteriorloss = torch.clamp((posterior_predictions - targets.view(-1, 1)) ** 2, max=self.upper_limit)\n",
        "\n",
        "                    # The prior forward pass is evaluated without tracking gradients.\n",
        "                    with torch.no_grad():\n",
        "                        prior_prediction = self.priornet(states)[0]\n",
        "                        priorloss = torch.clamp((prior_prediction - targets.view(-1, 1)) ** 2, max=self.upper_limit)\n",
        "\n",
        "                    # First element returned by the KL helper for (net, priornet).\n",
        "                    kl = calculate_kl_terms_informative(self.net, self.priornet)[0]\n",
        "                    self.kl_train_list.append(kl.item())\n",
        "                    # Objective: clamped loss_fn value plus the KL penalty of non_recursive_loss.\n",
        "                    mse_loss = self.loss_fn(posterior_predictions, targets.view(-1, 1))\n",
        "                    mse_loss = torch.clamp(mse_loss, max=self.upper_limit)\n",
        "                    loss = self.non_recursive_loss(mse_loss, kl, n_bound, self.upper_limit)\n",
        "\n",
        "                    self.emploss_prior[i].append(priorloss.mean().detach().cpu().item())\n",
        "                    self.emploss_posterior[i].append(posteriorloss.mean().detach().cpu().item())\n",
        "\n",
        "                    self.optim_net.zero_grad()\n",
        "                    loss.backward()\n",
        "                    # Clip the gradient norm to 1 before the optimizer step.\n",
        "                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=1)\n",
        "                    self.optim_net.step()\n",
        "\n",
        "                    epoch_loss += loss.item()\n",
        "\n",
        "                avg_loss = epoch_loss / len(dataloader)\n",
        "                print(f\"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_loss:.4f}, Episodes: {len(episode_indices)}\")\n",
        "                # One scheduler step per epoch; the scheduler is not reset between stages.\n",
        "                self.scheduler_net.step()\n",
        "\n",
        "            # After the first stage, the trained net becomes the prior for the second stage.\n",
        "            if self.loss_type == \"informative\" and step == half:\n",
        "                self.priornet.load_state_dict(copy.deepcopy(self.net.state_dict()))\n",
        "\n",
        "            previous_step = step\n",
        "\n",
        "        # Here previous_step == step == the last entry of step_sizes.\n",
        "        self.risk_input(previous_step, self.upper_limit, step)\n",
        "\n",
        "    def risk_input(self, prev_step, upper_limit, step):\n",
        "        \"\"\"Record per-sample evaluation tuples and loss extrema for ``self.G[prev_step:]``.\n",
        "\n",
        "        For every (state, target) pair the posterior net and the prior net are\n",
        "        evaluated without gradients. The running max/min of both squared errors\n",
        "        are tracked, and one tuple per sample is appended to\n",
        "        ``self.emp_risk_data_informative[prev_step]``. The extrema are appended to\n",
        "        ``self.minmax`` and the KL value is stored in ``self.kl_final``.\n",
        "\n",
        "        Args:\n",
        "            prev_step: Index of the first episode to evaluate. It is overridden with\n",
        "                ``int(len(self.G)/2)`` when ``loss_type`` is \"informative\" and with\n",
        "                ``0`` when it is \"noninformative\".\n",
        "            upper_limit: Cap applied to the squared error stored in each tuple.\n",
        "            step: Not used by this method.\n",
        "        \"\"\"\n",
        "        if self.loss_type == \"noninformative\":\n",
        "            prev_step = 0\n",
        "        elif self.loss_type == \"informative\":\n",
        "            prev_step = int(len(self.G)/2)\n",
        "\n",
        "        kl = calculate_kl_terms_informative(self.net, self.priornet)[0]\n",
        "        self.kl_final = kl.item()\n",
        "\n",
        "        # Running extrema start at +/- infinity so the first sample replaces them.\n",
        "        neg_inf, pos_inf = float('-inf'), float('inf')\n",
        "        post_hi = torch.tensor(neg_inf, device=self.device)\n",
        "        post_lo = torch.tensor(pos_inf, device=self.device)\n",
        "        prior_hi = torch.tensor(neg_inf, device=self.device)\n",
        "        prior_lo = torch.tensor(pos_inf, device=self.device)\n",
        "\n",
        "        records = self.emp_risk_data_informative\n",
        "        with torch.no_grad():\n",
        "            for episode in self.G[prev_step:]:\n",
        "                for state, target in episode:\n",
        "                    target = torch.tensor(target, device=self.device).unsqueeze(0)\n",
        "                    target_col = target.view(-1, 1)\n",
        "                    state_on_device = state.to(self.device)\n",
        "\n",
        "                    # Posterior forward pass first, then the prior one.\n",
        "                    prediction, mean_var = self.net(state_on_device)\n",
        "                    prior_prediction = self.priornet(state_on_device)[0]\n",
        "\n",
        "                    post_sq_err = (prediction - target_col) ** 2\n",
        "                    prior_sq_err = (prior_prediction - target_col) ** 2\n",
        "\n",
        "                    post_hi = torch.max(post_hi, post_sq_err)\n",
        "                    post_lo = torch.min(post_lo, post_sq_err)\n",
        "                    prior_hi = torch.max(prior_hi, prior_sq_err)\n",
        "                    prior_lo = torch.min(prior_lo, prior_sq_err)\n",
        "\n",
        "                    # The stored tuple keeps the original (not device-moved) state.\n",
        "                    capped_sq_err = torch.clamp(post_sq_err, max=upper_limit)\n",
        "                    records[prev_step].append(\n",
        "                        (state, target, prediction, upper_limit, kl, mean_var, capped_sq_err)\n",
        "                    )\n",
        "\n",
        "        self.minmax.append((post_hi, post_lo, prior_hi, prior_lo))\n",
        "    \n",
        "    def plot_loss(self):\n",
        "        \"\"\"Save a 2x2 figure with the diagnostics recorded by ``train``.\n",
        "\n",
        "        Layout:\n",
        "            * top row: one panel per training portion, with the posterior loss on\n",
        "              the left axis and the prior loss on a twin right axis;\n",
        "            * bottom-left: the KL values stored in ``self.kl_train_list``.\n",
        "\n",
        "        The figure is written under ``Ant_validationround_100000/seed_XX/`` (``XX``\n",
        "        being the seed zero-padded to two digits, the same convention the main\n",
        "        script uses for its log path) and named after ``self.loss_type``. The\n",
        "        target directory must already exist.\n",
        "        \"\"\"\n",
        "        fig, axes = plt.subplots(2, 2, figsize=(12, 8))\n",
        "        portions = [0] if self.loss_type == \"noninformative\" else [0, 1]\n",
        "\n",
        "        # --- KL term recorded during training ---\n",
        "        kl_ax = axes[1, 0]\n",
        "        kl_ax.plot(self.kl_train_list, label=\"kl train\", color=\"red\")\n",
        "        kl_ax.set_ylabel(\"kl in training\")\n",
        "        # kl_train_list receives one entry per minibatch, so the x-axis counts steps.\n",
        "        kl_ax.set_xlabel(\"Steps\")\n",
        "        kl_ax.legend(loc=\"upper left\")\n",
        "\n",
        "        # --- Posterior / prior loss per training portion ---\n",
        "        for idx, portion in enumerate(portions):\n",
        "            row, col = divmod(idx, 2)\n",
        "            ax = axes[row, col]\n",
        "\n",
        "            if self.emploss_posterior[portion]:\n",
        "                ax.plot(self.emploss_posterior[portion], label=\"Posterior Loss\", color=\"blue\")\n",
        "                ax.set_ylabel(f\"Posterior (portion {portion + 1})\", color=\"blue\")\n",
        "                ax.set_xlabel(\"Steps\")\n",
        "                ax.legend(loc=\"upper left\")\n",
        "\n",
        "            # The prior loss gets its own y-axis because its scale can differ.\n",
        "            ax2 = ax.twinx()\n",
        "            if self.emploss_prior[portion]:\n",
        "                ax2.plot(self.emploss_prior[portion], label=\"Prior Loss\", color=\"green\")\n",
        "                ax2.set_ylabel(f\"Prior (portion {portion + 1})\", color=\"green\")\n",
        "                ax2.legend(loc=\"upper right\")\n",
        "\n",
        "        plt.tight_layout()\n",
        "        # zfill keeps seeds >= 10 as seed_10 instead of seed_010, and the instance's\n",
        "        # own loss_type is used rather than a module-level global.\n",
        "        seed_dir = f\"seed_{str(self.seed).zfill(2)}\"\n",
        "        plt.savefig(f\"Ant_validationround_100000/{seed_dir}/Zplotslogs_NR_{self.loss_type}.png\")\n",
        "        plt.close(fig)\n",
        "\n",
        "    def predict(self):\n",
        "        \"\"\"Return the mean clamped squared error of ``self.net`` on test and train data.\n",
        "\n",
        "        The test episodes are built with ``self.get_G(self.test_tuples, \"test\")`` and\n",
        "        stored in ``self.G_test``; the training episodes are ``self.G``. Each state\n",
        "        is passed through the network individually, without gradients, and the\n",
        "        squared errors are clamped at the global ``env_limit`` before averaging.\n",
        "\n",
        "        Returns:\n",
        "            A ``(loss_test, loss_train)`` tuple of scalar tensors.\n",
        "        \"\"\"\n",
        "        self.G_test = self.get_G(self.test_tuples,\"test\")\n",
        "\n",
        "        def flatten_episodes(episodes):\n",
        "            # Merge all episodes into one state tensor and one float32 target tensor.\n",
        "            samples = [sample for episode in episodes for sample in episode]\n",
        "            states = torch.stack([s for s, _ in samples]).to(self.device)\n",
        "            targets = torch.tensor([t for _, t in samples], dtype=torch.float32).to(self.device)\n",
        "            return states, targets\n",
        "\n",
        "        def clamped_mse(predictions, targets):\n",
        "            # Mean of the squared errors after capping each one at env_limit.\n",
        "            return torch.clamp((predictions - targets) ** 2, max=env_limit).mean()\n",
        "\n",
        "        test_states, test_targets = flatten_episodes(self.G_test)\n",
        "        train_states, train_targets = flatten_episodes(self.G)\n",
        "\n",
        "        with torch.no_grad():\n",
        "            # One forward pass per state; the test set is evaluated before the train set.\n",
        "            test_predictions = torch.stack([self.net(state)[0] for state in test_states]).flatten()\n",
        "            train_predictions = torch.stack([self.net(state)[0] for state in train_states]).flatten()\n",
        "\n",
        "        loss_test = clamped_mse(test_predictions, test_targets)\n",
        "        loss_train = clamped_mse(train_predictions, train_targets)\n",
        "        return loss_test, loss_train\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Command-line options. parse_known_args ignores arguments it does not\n",
        "    # recognise, so extra arguments supplied by a notebook kernel are tolerated.\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--env\", type=str, default=\"ant\")\n",
        "    parser.add_argument(\"--model\", type=str, default=\"validationround_100000\")\n",
        "    parser.add_argument(\"--seed\", type=int, default=1)\n",
        "    parser.add_argument(\"--eval_episodes\", type=int, default=100)\n",
        "    parser.add_argument(\"--test_episodes\", type=int, default=200)\n",
        "    args, unknown = parser.parse_known_args()\n",
        "\n",
        "    env = args.env\n",
        "    model = args.model\n",
        "    seed = args.seed\n",
        "    eval_episodes = args.eval_episodes\n",
        "    test_episodes = args.test_episodes\n",
        "\n",
        "    # env_limit is read as a global by the train and predict methods, where it\n",
        "    # caps the squared errors. Fail here for an unknown environment instead of\n",
        "    # hitting a NameError (or a stale value from an earlier run) later on.\n",
        "    if env == \"ant\":\n",
        "        env_limit = 500000\n",
        "    elif env == \"walker2d\":\n",
        "        env_limit = 300000\n",
        "    elif \"humanoid\" in env:\n",
        "        env_limit = 300000\n",
        "    elif env == \"cheetah\":\n",
        "        env_limit = 900000\n",
        "    elif env == \"hopper\":\n",
        "        env_limit = 300000\n",
        "    else:\n",
        "        raise ValueError(\n",
        "            f\"Unsupported env {env!r}: expected ant, walker2d, cheetah, hopper \"\n",
        "            \"or a name containing 'humanoid'\"\n",
        "        )\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    # NOTE(review): the directory name is fixed and does not follow --env/--model.\n",
        "    path = f\"Ant_validationround_100000/seed_{str(seed).zfill(2)}\"\n",
        "    loss_type = \"informative\" # or \"noninformative\"\n",
        "    model_type = \"fullvb\"\n",
        "\n",
        "    test_losses = []\n",
        "    train_losses = []\n",
        "    kl_list = []\n",
        "    nr_bounds = []\n",
        "    n_bounds_list = []\n",
        "\n",
        "    model_eval = PerformanceEval(env, model, seed, eval_episodes) #, fold_index=fold_index)\n",
        "    model_eval.calculate_bound()\n",
        "    loss_test, loss_train = model_eval.predict()\n",
        "\n",
        "    # Named log_entries (not log) so the `log` imported from math at the top of\n",
        "    # the notebook is not rebound. Losses are stored as plain floats so the log\n",
        "    # file shows numbers rather than tensor reprs.\n",
        "    log_entries = {\n",
        "        \"eval_episodes\": eval_episodes,\n",
        "        \"kl\": model_eval.kl_final,\n",
        "        \"NR_type\": loss_type,\n",
        "        \"Non-recursive invkl bound\": model_eval.NR_results,\n",
        "        \"Test loss\": loss_test.item(),\n",
        "        \"Train loss\": loss_train.item(),\n",
        "        \"_NR_n_bound\": model_eval._NR_n_bound\n",
        "    }\n",
        "    with open(f\"{path}/NR_{loss_type}_{model_type}_{model_eval.learning_rate}.log\", \"w\") as f:\n",
        "        now = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
        "        f.write(f\"Log created on: {now}\\n\\n\")\n",
        "        f.write(f\"local reparametrization used: {model_eval.local_reparametrization}\\n\")\n",
        "        for key, value in log_entries.items():\n",
        "            f.write(f\"{key}: {value}\\n\")\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "demo",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
