{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import joblib\n",
    "\n",
    "class RingMdp:\n",
    "    \"\"\"\n",
    "    A class to represent a ring-shaped Markov Decision Process (MDP).\n",
    "\n",
    "    Attributes:\n",
    "    -----------\n",
    "    num_states : int\n",
    "        Number of states in the MDP.\n",
    "    p : float\n",
    "        Probability of additional random connections.\n",
    "    num_actions : int\n",
    "        Number of actions in the MDP.\n",
    "    transitions : np.ndarray\n",
    "        Transition probabilities for the MDP.\n",
    "    rewards : np.ndarray\n",
    "        Rewards for the MDP.\n",
    "\n",
    "    Methods:\n",
    "    --------\n",
    "    generate_mdp():\n",
    "        Generates the transition probabilities and rewards for an mdp of type ring. See the paper for details.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_states=10, p=0):\n",
    "        self.num_states = num_states\n",
    "        self.p = p\n",
    "        self.num_actions = 2\n",
    "        self.transitions = np.zeros((2, num_states, num_states))\n",
    "        self.rewards = np.zeros((num_states, 2))\n",
    "        self.generate_mdp()\n",
    "\n",
    "    def generate_mdp(self):\n",
    "        for state in range(self.num_states):\n",
    "            next_state_clockwise = (state + 1) % self.num_states\n",
    "            prev_state_counterclockwise = (state - 1) % self.num_states\n",
    "            \n",
    "            # Action 1: Move clockwise or stay\n",
    "            self.transitions[0, state, [next_state_clockwise, state]] = np.random.uniform(0, 1, 2)\n",
    "\n",
    "            # Action 2: Move counterclockwise or stay\n",
    "            self.transitions[1, state, [prev_state_counterclockwise, state]] = np.random.uniform(0, 1, 2)\n",
    "            \n",
    "            # Additional random connections based on probability p\n",
    "            for other_state in range(self.num_states):\n",
    "                if other_state != state and other_state != next_state_clockwise and np.random.rand() < self.p:\n",
    "                    self.transitions[0, state, other_state] = np.random.uniform(0, 1)\n",
    "                if other_state != state and other_state != prev_state_counterclockwise and np.random.rand() < self.p:\n",
    "                    self.transitions[1, state, other_state] = np.random.uniform(0, 1)\n",
    "            \n",
    "            # Normalize again after adding random connections\n",
    "            self.transitions[0, state, :] /= self.transitions[0, state, :].sum()\n",
    "            self.transitions[1, state, :] /= self.transitions[1, state, :].sum()\n",
    "            \n",
    "            # Generate rewards\n",
    "            self.rewards[state, 0] = np.random.uniform(0, 1)\n",
    "            self.rewards[state, 1] = np.random.uniform(0, 1)\n",
    "\n",
    "def generate_ring_mdps(num_mdps=1000, num_states=10, p=0):\n",
    "    return [RingMdp(num_states, p) for _ in range(num_mdps)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For this section, we will use the fixed MDP but this can be easily adapted to ring MDP or any other MDP.\n",
    "\n",
    "#We assume the rewards are known for our certainty equivalence setting similar to (The Dependence of Effective Planning Horizon on Model Accuracy Jiang et al 2015)\n",
    "#We compute approximate transitions from n_sample and deduce the optimal policy to compare the bound.\n",
    "\n",
    "class EstimatedRingMDP:\n",
    "    def __init__(self, transitions, rewards, num_states, p):\n",
    "        self.num_states = num_states\n",
    "        self.p = p\n",
    "        self.num_actions = 2\n",
    "        self.transitions = transitions\n",
    "        self.rewards = rewards\n",
    "\n",
    "def estimate_single_mdp(mdp, n_samples):\n",
    "    \"\"\"\n",
    "    Estimate the transition probabilities of a single MDP using sample transitions.\n",
    "\n",
    "    Parameters:\n",
    "    mdp (MDP): The Markov Decision Process to estimate.\n",
    "    n_samples (int): The number of samples to use for estimating the transitions.\n",
    "\n",
    "    Returns:\n",
    "    EstimatedMDP: A new MDP with estimated transition probabilities.\n",
    "    \"\"\"\n",
    "    num_states = mdp.num_states\n",
    "    num_actions = mdp.num_actions\n",
    "    estimated_transitions = np.zeros((num_actions, num_states, num_states))\n",
    "    \n",
    "    for state in range(num_states):\n",
    "        for action in range(num_actions):\n",
    "            next_states = np.random.choice(num_states, size=n_samples, p=mdp.transitions[action, state, :])\n",
    "            \n",
    "            for next_state in range(num_states):\n",
    "                estimated_transitions[action, state, next_state] = np.sum(next_states == next_state) / n_samples\n",
    "\n",
    "    return EstimatedRingMDP(transitions=estimated_transitions, rewards=mdp.rewards, num_states=num_states, p=mdp.p)\n",
    "\n",
    "#This function exists for paralellization purposes\n",
    "def estimate_mdps(mdps, n_samples):\n",
    "    return [estimate_single_mdp(mdp, n_samples) for mdp in mdps]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_sizes = [1, 5, 10, 25, 100]\n",
    "V_ests_all = []\n",
    "V_trues_all = []\n",
    "\n",
    "for sample_size in sample_sizes:\n",
    "    ring_mdps = generate_ring_mdps(num_mdps=1000, num_states=10, p=0.125)\n",
    "    true_rewards = np.array([mdp.rewards for mdp in ring_mdps])\n",
    "    true_transitions = np.array([mdp.transitions for mdp in ring_mdps])\n",
    "\n",
    "    estimated_ring_mdps = estimate_mdps(ring_mdps, n_samples=sample_size)\n",
    "    estimated_rewards = np.array([mdp.rewards for mdp in estimated_ring_mdps])\n",
    "    estimated_transitions = np.array([mdp.transitions for mdp in estimated_ring_mdps])\n",
    "\n",
    "    M, A, S  = true_transitions.shape[0], 2, 10          # 1000, 2, 10\n",
    "    betas    = np.arange(1.00, -0.02, -0.02)        # length B = 50\n",
    "    B        = betas.size\n",
    "    K0       = 2                                         # base intra-policies\n",
    "    K_max  = B * K0                   # 300 options\n",
    "    gamma  = 0.99\n",
    "    I      = np.eye(S)[None,None]     # (1,1,S,S)\n",
    "    deliberation_cost = 0.1\n",
    "\n",
    "    options = np.stack([\n",
    "        np.vstack([np.ones(S),  np.zeros(S)]),   # always action 0\n",
    "        np.vstack([np.zeros(S), np.ones(S)])     # always action 1\n",
    "    ])                                           # (3,2,10)\n",
    "\n",
    "    # ---- multi-time transition P_block -----------------------------\n",
    "    # P_base : (M, K0, S, S)\n",
    "    P_one_step_base = np.einsum('kas,masu->mksu', options, true_transitions)\n",
    "\n",
    "    # r_base : (M, K0, S)\n",
    "    r_one_step_base = np.einsum('kas,msa -> mks',  options, true_rewards)\n",
    "\n",
    "    # P_base : (M, K0, S, S)\n",
    "    P_one_step_estimated = np.einsum('kas,masu->mksu', options, estimated_transitions)\n",
    "    # r_estimated : (M, K0, S)\n",
    "    r_one_step_estimated = np.einsum('kas,msa -> mks',  options, estimated_rewards)\n",
    "\n",
    "\n",
    "    # ---- big containers ------------------------------------------------\n",
    "    # big containers\n",
    "    R_all        = np.empty((M, K_max,       S))\n",
    "    P_all        = np.empty((M, K_max,   S, S))\n",
    "    R_all_est    = np.empty_like(R_all)\n",
    "    P_all_est    = np.empty_like(P_all)\n",
    "\n",
    "    cut = np.empty(B, dtype=int)      # prefix-length table\n",
    "    offset = 0\n",
    "\n",
    "    for i, beta in enumerate(betas):\n",
    "\n",
    "        # ---- discounted return R_block --------------------------------\n",
    "        inv  = np.linalg.inv(I - gamma*(1-beta)*P_one_step_base)\n",
    "        R_blk  = np.einsum('mksu,mku->mks', inv, r_one_step_base)\n",
    "\n",
    "        inv_e = np.linalg.inv(I - gamma*(1-beta)*P_one_step_estimated)\n",
    "        R_blk_e = np.einsum('mksu,mku->mks', inv_e, r_one_step_estimated)\n",
    "\n",
    "        # ---- multi-time transition P_block -----------------------------\n",
    "        P_blk   = gamma * beta * np.einsum('mksu,mkuv->mksv',\n",
    "                                       P_one_step_base,      inv)\n",
    "        P_blk_e = gamma * beta * np.einsum('mksu,mkuv->mksv',\n",
    "                                        P_one_step_estimated, inv_e)\n",
    "        # ---- write block straight into the big tensors ----------------\n",
    "        sl = slice(offset, offset + K0)          # exactly 3 columns\n",
    "        R_all[:, sl]     = R_blk\n",
    "        P_all[:, sl]     = P_blk\n",
    "        R_all_est[:, sl] = R_blk_e\n",
    "        P_all_est[:, sl] = P_blk_e\n",
    "\n",
    "        offset      += K0                        # advance by 3\n",
    "        cut[i]       = offset                    # store prefix length\n",
    "\n",
    "    def prefix(b):\n",
    "        end = cut[b]                    # e.g. 3, 6, 9, …, 300\n",
    "        return (R_all[:, :end], P_all[:, :end],\n",
    "                R_all_est[:, :end], P_all_est[:, :end])\n",
    "\n",
    "\n",
    "    def option_value_iteration(R, P, *, tol=1e-3, max_iter=1000):\n",
    "        \"\"\"\n",
    "        R : (M,K,S)      multi-time option rewards\n",
    "        P : (M,K,S,S)    multi-time transition matrices\n",
    "        \"\"\"\n",
    "        M, K, S = R.shape\n",
    "        V  = np.zeros((M, S))\n",
    "        pi = np.zeros((M, S), dtype=int)\n",
    "\n",
    "        R_fee = R - deliberation_cost\n",
    "\n",
    "        for _ in range(max_iter):\n",
    "            # Q(m,k,s) = R + Σ_u P(m,k,s,u) V(m,u)\n",
    "            Q = R_fee + np.einsum('mksu,mu->mks', P, V)\n",
    "\n",
    "            V_new  = Q.max(axis=1)       # (M,S)\n",
    "            pi_new = Q.argmax(axis=1)    # (M,S)\n",
    "\n",
    "            if np.max(np.abs(V_new - V)) < tol:\n",
    "                break\n",
    "            V, pi = V_new, pi_new\n",
    "\n",
    "        return V_new, pi_new\n",
    "\n",
    "    # Create lists to store results for all prefixes\n",
    "    V_stars = []\n",
    "    pi_stars = []\n",
    "    V_star_ests = []\n",
    "    pi_star_ests = []\n",
    "\n",
    "    # Process all prefixes in parallel\n",
    "    results = joblib.Parallel(n_jobs=-1)(\n",
    "        joblib.delayed(lambda b: (\n",
    "            option_value_iteration(*prefix(b)[:2]),\n",
    "            option_value_iteration(*prefix(b)[2:])\n",
    "        ))(b) for b in range(B)\n",
    "    )\n",
    "\n",
    "    # Unpack results\n",
    "    for (V_star, pi_star), (V_star_est, pi_star_est) in results:\n",
    "        V_stars.append(V_star)\n",
    "        pi_stars.append(pi_star)\n",
    "        V_star_ests.append(V_star_est)\n",
    "        pi_star_ests.append(pi_star_est)\n",
    "\n",
    "    # Convert to numpy arrays for easier indexing\n",
    "    V_stars = np.array(V_stars)\n",
    "    pi_stars = np.array(pi_stars)\n",
    "    V_star_ests = np.array(V_star_ests)\n",
    "    pi_star_ests = np.array(pi_star_ests)\n",
    "\n",
    "    def evaluate_option_policy(R, P, pi):\n",
    "        M, K, S = R.shape\n",
    "\n",
    "        # Create index helpers\n",
    "        rows = np.arange(M)[:, None]        # shape (M, 1)\n",
    "        cols = np.arange(S)[None, :]        # shape (1, S)\n",
    "        # Fancy indexing to get R_pi[m, s] = R[m, pi[m, s], s]\n",
    "        R_pi = R[rows, pi, cols] - deliberation_cost           # shape (M, S)\n",
    "\n",
    "        # Get P_pi[m, s, s'] = P[m, pi[m, s], s, s']\n",
    "        P_pi = P[rows, pi, cols, :]         # shape (M, S, S)\n",
    "\n",
    "        # Solve (I - P_pi[m]) V = R_pi[m] for each m\n",
    "        I = np.eye(S)[None, :, :]           # shape (1, S, S) for broadcasting\n",
    "        inv = np.linalg.inv(I - P_pi)\n",
    "        V_pi = np.einsum('msu,mu->ms', inv, R_pi)  # shape (M, S)\n",
    "\n",
    "        return V_pi\n",
    "\n",
    "    # R_beta, P_beta are the true-model tensors for that prefix\n",
    "    # pi_star       is what option_value_iteration gave you on the *estimated* model\n",
    "    V_true = evaluate_option_policy(R_all,P_all, pi_stars[-1])\n",
    "\n",
    "    V_ests = joblib.Parallel(n_jobs=-1)(\n",
    "        joblib.delayed(lambda b: evaluate_option_policy(R_all,P_all, pi_star_ests[b]))(b) \n",
    "        for b in range(B)\n",
    "    )\n",
    "\n",
    "    # Convert to numpy arrays for easier indexing\n",
    "    V_trues = np.broadcast_to(V_true, (B, *V_true.shape))\n",
    "    V_ests = np.array(V_ests)\n",
    "\n",
    "    V_ests_all.append(V_ests)\n",
    "    V_trues_all.append(V_trues)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the plot to PDF insteadd of showing in window\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Reset to matplotlib defaults (fonts, etc.)\n",
    "plt.rcdefaults()\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "\n",
    "for i in range(len(V_ests_all)):\n",
    "    average_loss = np.max((V_trues_all[i]-V_ests_all[i])/V_trues_all[i], axis=2).mean(axis=1)\n",
    "    std_loss = np.max((V_trues_all[i]-V_ests_all[i])/V_trues_all[i], axis=2).std(axis=1)/np.sqrt(1000)\n",
    "    line = plt.plot(betas, average_loss, label=f'n = {sample_sizes[i]}')\n",
    "    plt.errorbar(betas, average_loss, yerr=std_loss, fmt='none', capsize=5, alpha=0.3, color=line[0].get_color())\n",
    "    min_idx = np.argmin(average_loss)\n",
    "    plt.plot(betas[min_idx], average_loss[min_idx], '*', markersize=10, color='red')\n",
    "\n",
    "plt.xlabel(r'$\\beta_{\\mathrm{eval}}$', fontsize=14)\n",
    "plt.ylabel('Commitment loss', fontsize=14)\n",
    "plt.grid(True)\n",
    "plt.legend(fontsize=14)\n",
    "plt.tick_params(labelsize=14)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Instead of plt.show(), save to PDF\n",
    "plt.savefig(\"commitment_loss_vs_beta.pdf\", format=\"pdf\")\n",
    "plt.show()\n",
    "plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_size = 1\n",
    "deliberation_costs = [0.01,0.1,0.2,0.3]\n",
    "V_ests_all_deliberation_costs = []\n",
    "V_trues_all_deliberation_costs = []\n",
    "\n",
    "for deliberation_cost in deliberation_costs:\n",
    "    print(\"deliberation cost: \", deliberation_cost)\n",
    "    ring_mdps = generate_ring_mdps(num_mdps=1000, num_states=10, p=0.125)\n",
    "    true_rewards = np.array([mdp.rewards for mdp in ring_mdps])\n",
    "    true_transitions = np.array([mdp.transitions for mdp in ring_mdps])\n",
    "\n",
    "    estimated_ring_mdps = estimate_mdps(ring_mdps, n_samples=sample_size)\n",
    "    estimated_rewards = np.array([mdp.rewards for mdp in estimated_ring_mdps])\n",
    "    estimated_transitions = np.array([mdp.transitions for mdp in estimated_ring_mdps])\n",
    "\n",
    "    M, A, S  = true_transitions.shape[0], 2, 10          # 1000, 2, 10\n",
    "    betas    = np.arange(1.00, -0.02, -0.02)        # length B = 100\n",
    "    B        = betas.size\n",
    "    K0       = 3                                         # base intra-policies\n",
    "    K_max  = B * K0                   # 300 options\n",
    "    gamma  = 0.99\n",
    "    I      = np.eye(S)[None,None]     # (1,1,S,S)\n",
    "\n",
    "    options = np.stack([\n",
    "        np.full((A, S), 0.5),                    # 50 / 50\n",
    "        np.vstack([np.ones(S),  np.zeros(S)]),   # always action 0\n",
    "        np.vstack([np.zeros(S), np.ones(S)])     # always action 1\n",
    "    ])                                           # (3,2,10)\n",
    "\n",
    "    # ---- multi-time transition P_block -----------------------------\n",
    "    # P_base : (M, K0, S, S)\n",
    "    P_one_step_base = np.einsum('kas,masu->mksu', options, true_transitions)\n",
    "\n",
    "    # r_base : (M, K0, S)\n",
    "    r_one_step_base = np.einsum('kas,msa -> mks',  options, true_rewards)\n",
    "\n",
    "    # P_base : (M, K0, S, S)\n",
    "    P_one_step_estimated = np.einsum('kas,masu->mksu', options, estimated_transitions)\n",
    "    # r_estimated : (M, K0, S)\n",
    "    r_one_step_estimated = np.einsum('kas,msa -> mks',  options, estimated_rewards)\n",
    "\n",
    "\n",
    "    # ---- big containers ------------------------------------------------\n",
    "    # big containers\n",
    "    R_all        = np.empty((M, K_max,       S))\n",
    "    P_all        = np.empty((M, K_max,   S, S))\n",
    "    R_all_est    = np.empty_like(R_all)\n",
    "    P_all_est    = np.empty_like(P_all)\n",
    "\n",
    "    cut = np.empty(B, dtype=int)      # prefix-length table\n",
    "    offset = 0\n",
    "\n",
    "    for i, beta in enumerate(betas):\n",
    "\n",
    "        # ---- discounted return R_block --------------------------------\n",
    "        inv  = np.linalg.inv(I - gamma*(1-beta)*P_one_step_base)\n",
    "        R_blk  = np.einsum('mksu,mku->mks', inv, r_one_step_base)\n",
    "\n",
    "        inv_e = np.linalg.inv(I - gamma*(1-beta)*P_one_step_estimated)\n",
    "        R_blk_e = np.einsum('mksu,mku->mks', inv_e, r_one_step_estimated)\n",
    "\n",
    "        # ---- multi-time transition P_block -----------------------------\n",
    "        P_blk   = gamma * beta * np.einsum('mksu,mkuv->mksv',\n",
    "                                       P_one_step_base,      inv)\n",
    "        P_blk_e = gamma * beta * np.einsum('mksu,mkuv->mksv',\n",
    "                                        P_one_step_estimated, inv_e)\n",
    "        # ---- write block straight into the big tensors ----------------\n",
    "        sl = slice(offset, offset + K0)          # exactly 3 columns\n",
    "        R_all[:, sl]     = R_blk\n",
    "        P_all[:, sl]     = P_blk\n",
    "        R_all_est[:, sl] = R_blk_e\n",
    "        P_all_est[:, sl] = P_blk_e\n",
    "\n",
    "        offset      += K0                        # advance by 3\n",
    "        cut[i]       = offset                    # store prefix length\n",
    "\n",
    "    def prefix(b):\n",
    "        end = cut[b]                    # e.g. 3, 6, 9, …, 300\n",
    "        return (R_all[:, :end], P_all[:, :end],\n",
    "                R_all_est[:, :end], P_all_est[:, :end])\n",
    "\n",
    "\n",
    "    def option_value_iteration(R, P, *, tol=1e-3, max_iter=1000):\n",
    "        \"\"\"\n",
    "        R : (M,K,S)      multi-time option rewards\n",
    "        P : (M,K,S,S)    multi-time transition matrices\n",
    "        \"\"\"\n",
    "        M, K, S = R.shape\n",
    "        V  = np.zeros((M, S))\n",
    "        pi = np.zeros((M, S), dtype=int)\n",
    "\n",
    "        R_fee = R - deliberation_cost\n",
    "\n",
    "        for _ in range(max_iter):\n",
    "            # Q(m,k,s) = R + Σ_u P(m,k,s,u) V(m,u)\n",
    "            Q = R_fee + np.einsum('mksu,mu->mks', P, V)\n",
    "\n",
    "            V_new  = Q.max(axis=1)       # (M,S)\n",
    "            pi_new = Q.argmax(axis=1)    # (M,S)\n",
    "\n",
    "            if np.max(np.abs(V_new - V)) < tol:\n",
    "                break\n",
    "            V, pi = V_new, pi_new\n",
    "            \n",
    "        return V_new, pi_new\n",
    "\n",
    "    # Create lists to store results for all prefixes\n",
    "    V_stars = []\n",
    "    pi_stars = []\n",
    "    V_star_ests = []\n",
    "    pi_star_ests = []\n",
    "\n",
    "    # Process all prefixes in parallel\n",
    "    results = joblib.Parallel(n_jobs=-1)(\n",
    "        joblib.delayed(lambda b: (\n",
    "            option_value_iteration(*prefix(b)[:2]),\n",
    "            option_value_iteration(*prefix(b)[2:])\n",
    "        ))(b) for b in range(B)\n",
    "    )\n",
    "\n",
    "    # Unpack results\n",
    "    for (V_star, pi_star), (V_star_est, pi_star_est) in results:\n",
    "        V_stars.append(V_star)\n",
    "        pi_stars.append(pi_star)\n",
    "        V_star_ests.append(V_star_est)\n",
    "        pi_star_ests.append(pi_star_est)\n",
    "\n",
    "    # Convert to numpy arrays for easier indexing\n",
    "    V_stars = np.array(V_stars)\n",
    "    pi_stars = np.array(pi_stars)\n",
    "    V_star_ests = np.array(V_star_ests)\n",
    "    pi_star_ests = np.array(pi_star_ests)\n",
    "\n",
    "    def evaluate_option_policy(R, P, pi):\n",
    "        M, K, S = R.shape\n",
    "\n",
    "        # Create index helpers\n",
    "        rows = np.arange(M)[:, None]        # shape (M, 1)\n",
    "        cols = np.arange(S)[None, :]        # shape (1, S)\n",
    "        # Fancy indexing to get R_pi[m, s] = R[m, pi[m, s], s]\n",
    "        R_pi = R[rows, pi, cols] - deliberation_cost           # shape (M, S)\n",
    "\n",
    "        # Get P_pi[m, s, s'] = P[m, pi[m, s], s, s']\n",
    "        P_pi = P[rows, pi, cols, :]         # shape (M, S, S)\n",
    "\n",
    "        # Solve (I - P_pi[m]) V = R_pi[m] for each m\n",
    "        I = np.eye(S)[None, :, :]           # shape (1, S, S) for broadcasting\n",
    "        inv = np.linalg.inv(I - P_pi)\n",
    "        V_pi = np.einsum('msu,mu->ms', inv, R_pi)  # shape (M, S)\n",
    "\n",
    "        return V_pi\n",
    "\n",
    "    # R_beta, P_beta are the true-model tensors for that prefix\n",
    "    # pi_star       is what option_value_iteration gave you on the *estimated* model\n",
    "    V_true = evaluate_option_policy(R_all,P_all, pi_stars[-1])\n",
    "\n",
    "    V_ests = joblib.Parallel(n_jobs=-1)(\n",
    "        joblib.delayed(lambda b: evaluate_option_policy(R_all,P_all, pi_star_ests[b]))(b) \n",
    "        for b in range(B)\n",
    "    )\n",
    "\n",
    "    # Convert to numpy arrays for easier indexing\n",
    "    V_trues = np.broadcast_to(V_true, (B, *V_true.shape))\n",
    "    V_ests = np.array(V_ests)\n",
    "\n",
    "    V_ests_all_deliberation_costs.append(V_ests)\n",
    "    V_trues_all_deliberation_costs.append(V_trues)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Reset to matplotlib defaults (fonts, etc.)\n",
    "plt.rcdefaults()\n",
    "\n",
    "# Create the plot for deliberation costs\n",
    "plt.figure(figsize=(6, 4))\n",
    "average_losses = []\n",
    "# Plot for each deliberation cost value\n",
    "for i in range(len(V_ests_all_deliberation_costs)):\n",
    "    average_loss = np.max((V_trues_all_deliberation_costs[i]-V_ests_all_deliberation_costs[i])/V_trues_all_deliberation_costs[i], axis=2).mean(axis=1)\n",
    "    std_loss = np.max((V_trues_all_deliberation_costs[i]-V_ests_all_deliberation_costs[i])/V_trues_all_deliberation_costs[i], axis=2).std(axis=1)/np.sqrt(1000)\n",
    "    line = plt.plot(betas, average_loss, label=f'$C_{{\\\\mathrm{{max}}}} = {deliberation_costs[i]}$')\n",
    "    \n",
    "    # Add error bars using std_loss\n",
    "    plt.errorbar(betas, average_loss, yerr=std_loss, fmt='none', capsize=5, alpha=0.3, color=line[0].get_color())\n",
    "    \n",
    "    # Find and plot minimum point with red star\n",
    "    min_idx = np.argmin(average_loss)\n",
    "    plt.plot(betas[min_idx], average_loss[min_idx], '*', markersize=10, color='red')\n",
    "\n",
    "    average_losses.append(average_loss)\n",
    "\n",
    "plt.xlabel(r'$\\beta_{\\mathrm{eval}}$', fontsize=14)\n",
    "plt.ylabel('Commitment loss (model error)', fontsize=14)\n",
    "plt.grid(True)\n",
    "plt.legend(fontsize=14)\n",
    "plt.tick_params(labelsize=14)\n",
    "plt.tight_layout()\n",
    "plt.savefig('commitment_loss_vs_beta_deliberation_costs.pdf', bbox_inches='tight', dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font size parameter that can be adjusted\n",
    "FONT_SIZE = 18\n",
    "\n",
    "# Create a figure with two subplots side by side\n",
    "plt.figure(figsize=(12, 4))\n",
    "\n",
    "# Left subplot: deliberation costs\n",
    "plt.subplot(1, 2, 1)\n",
    "for i in range(len(V_ests_all_deliberation_costs)):\n",
    "    average_loss = np.max((V_trues_all_deliberation_costs[i]-V_ests_all_deliberation_costs[i])/V_trues_all_deliberation_costs[i], axis=2).mean(axis=1)\n",
    "    std_loss = np.max((V_trues_all_deliberation_costs[i]-V_ests_all_deliberation_costs[i])/V_trues_all_deliberation_costs[i], axis=2).std(axis=1)/np.sqrt(1000)\n",
    "    line = plt.plot(betas, average_loss, label=f'$C_{{\\\\mathrm{{max}}}} = {deliberation_costs[i]}$')\n",
    "    \n",
    "    # Add error bars using std_loss\n",
    "    plt.errorbar(betas, average_loss, yerr=std_loss, fmt='none', capsize=5, alpha=0.3, color=line[0].get_color())\n",
    "    \n",
    "    # Find and plot minimum point with red star\n",
    "    min_idx = np.argmin(average_loss)\n",
    "    plt.plot(betas[min_idx], average_loss[min_idx], '*', markersize=10, color='red')\n",
    "\n",
    "plt.xlabel(r'$\\beta_{\\mathrm{eval}}$', fontsize=FONT_SIZE)\n",
    "plt.ylabel('C.L (model error)', fontsize=FONT_SIZE)\n",
    "plt.grid(True)\n",
    "plt.legend(fontsize=14)\n",
    "plt.tick_params(labelsize=14)\n",
    "\n",
    "# Right subplot: Sample sizes\n",
    "plt.subplot(1, 2, 2)\n",
    "for i in range(len(V_ests_all)):\n",
    "    average_loss = np.max((V_trues_all[i]-V_ests_all[i])/V_trues_all[i], axis=2).mean(axis=1)\n",
    "    std_loss = np.max((V_trues_all[i]-V_ests_all[i])/V_trues_all[i], axis=2).std(axis=1)/np.sqrt(1000)\n",
    "    line = plt.plot(betas, average_loss, label=f'$n = {sample_sizes[i]}$')\n",
    "    \n",
    "    # Add error bars using std_loss\n",
    "    plt.errorbar(betas, average_loss, yerr=std_loss, fmt='none', capsize=5, alpha=0.3, color=line[0].get_color())\n",
    "    \n",
    "    # Find and plot minimum point with red star\n",
    "    min_idx = np.argmin(average_loss)\n",
    "    plt.plot(betas[min_idx], average_loss[min_idx], '*', markersize=10, color='red')\n",
    "\n",
    "plt.xlabel(r'$\\beta_{\\mathrm{eval}}$', fontsize=FONT_SIZE)\n",
    "plt.grid(True)\n",
    "plt.legend(fontsize=14)\n",
    "plt.tick_params(labelsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('commitment_loss_vs_beta_combined.pdf', bbox_inches='tight', dpi=300)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
