{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d5cd06",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import pickle\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "from ucb.value_functions import FiniteHorizonQRegressor\n",
    "from ucb.models import RBFGP, GPTrainer\n",
    "import ucb.envs\n",
    "import gym\n",
    "from sklearn.metrics import explained_variance_score\n",
    "from functools import partial\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32cc959f",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = Path('../experiments')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a21ad2ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"Uniform Beta + Rotation\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddbae842",
   "metadata": {},
   "source": [
    "- try uniform p_0 (implemented, running on beta tracking)\n",
    "- run FOVI + TQRL longer to see what happens at the end (going)\n",
    "- think about whether we can justify theoretically the current acquisition function * the occupancy distribution\n",
    "- Bias the prior in something like cartpole to make the prior mean low\n",
    "- try a lower beta (going)\n",
    "- run with the mean policy (going)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15ae326c",
   "metadata": {},
   "outputs": [],
   "source": [
    "mc_exps = {'GP All data': 'rbf_mountaincar_2022-09-08/17-06-15',\n",
    "           'CB All data': 'cb_mountaincar_2022-09-08/14-18-29',\n",
    "           \"GP TQRL All data\": 'default_2022-09-09/14-06-25',\n",
    "           'CB TQRL All data': 'cb_mc_tqrl_2022-09-09/14-17-11/',\n",
    "           'CB TQRL New': 'mc_cb_tqrl_2022-09-12/11-03-51/',\n",
    "           'GP TQRL New': 'mc_rbf_tqrl_2022-09-12/11-03-09/',\n",
    "           'CB TQRL LCB': 'mc_cb_tqrl_lcb_eval_2022-09-12/11-21-52/',\n",
    "           'GP TQRL LCB': 'mc_rbf_tqrl_q_lcb_eval_2022-09-12/15-11-46/',\n",
    "          }\n",
    "cp_exps = {\"CB All Data\": \"cb_cartpole_2022-09-09/10-30-58/\",\n",
    "           \"RBF All Data\": \"rbf_cartpole_2022-09-08/17-06-23/\",\n",
    "           \"TQRL LCB Visited\": \"TQRL_LCB_visited_cartpole_2022-09-19/14-06-29/\",\n",
    "            }\n",
    "dense_cp_exps = {\n",
    "    # 'GP All data': 'dense_cp_rbf_2022-09-12/15-11-17/',\n",
    "    # 'GP All data 2': 'dense_cp_rbf_2022-09-13/08-43-50/',\n",
    "    # 'GP All data TQRL LCB': 'dense_cp_rbf_lcb_tqrl_2022-09-14/09-35-28/',\n",
    "    # 'GP All data TQRL': 'dense_cp_rbf_tqrl_2022-09-14/23-02-12/',\n",
    "    'TQRL (seeds)': 'TQRL_dense_cartpole_2022-09-19/02-15-39/',\n",
    "    'US LCB': 'US_LCB_dense_cartpole_2022-09-19/05-37-27/',\n",
    "    'TQRL LCB': 'TQRL_LCB_dense_cartpole_2022-09-18/22-45-27/',\n",
    "    'FOVI': 'FOVI_dense_cartpole_2022-09-18/20-23-07/',\n",
    "    'Random': 'RANDOM_LCB_dense_cartpole_2022-09-19/07-52-35/',\n",
    "    'TQRL LCB Visited': 'TQRL_LCB_visited_dense_cartpole_2022-09-19/14-06-20/',\n",
    "    'Greedy': 'greedy_dense_cartpole_2022-09-25/08-54-32/',\n",
    "}\n",
    "\n",
    "uniform_dense_cp_exps = {\n",
    "    'AE-LSVI': 'AE_LSVI_dense_uniform_cartpole_2022-09-23/14-41-37/',\n",
    "    'LSVI': 'LSVI_dense_uniform_cartpole_2022-09-26/15-44-56/',\n",
    "    'AE-LSVI (visited)': 'AE_LSVI_visited_dense_uniform_cartpole_2022-09-23/14-41-26/',\n",
    "    'Random': 'RANDOM_LCB_dense_uniform_cartpole_2022-09-26/16-01-15/',\n",
    "    'US': 'US_LCB_dense_uniform_cartpole_2022-09-25/10-37-02/',\n",
    "    'Greedy': 'greedy_dense_uniform_cartpole_2022-09-25/10-39-49/',\n",
    "}\n",
    "\n",
    "\n",
    "beta_tracking_exps = {\n",
    "    # \"GP All Data\": 'beta_tracking_rbf_2022-09-13/16-22-03/',\n",
    "    # 'GP All data TQRL LCB': 'beta_tracking_rbf_lcb_tqrl_2022-09-14/09-42-37/',\n",
    "    # \"GP All Data TQRL\": 'beta_tracking_rbf_tqrl_2022-09-14/22-40-34/',\n",
    "    # 'TQRL LCB visited': 'beta_tracking_rbf_tqrl_lcb_visited_2022-09-18/13-39-42/',\n",
    "    'AE-LSVI (mean)': 'TQRL_beta_tracking_2022-09-18/22-14-46/',\n",
    "    'US': 'US_LCB_beta_tracking_2022-09-18/23-21-08/',\n",
    "    'AE-LSVI': 'TQRL_LCB_beta_tracking_2022-09-18/21-04-07/',\n",
    "    'LSVI': 'FOVI_beta_tracking_2022-09-18/20-22-36/',\n",
    "    'Random': 'RANDOM_LCB_beta_tracking_2022-09-19/00-05-14/',\n",
    "    'TQRL LCB Visited': 'TQRL_LCB_visited_beta_tracking_2022-09-19/14-05-44/',\n",
    "    'Greedy': 'greedy_beta_tracking_2022-09-20/20-35-26/',\n",
    "}\n",
    "\n",
    "uniform_beta_tracking_exps = {\n",
    "    'AE-LSVI': 'AE_LSVI_uniform_beta_tracking_2022-09-23/14-32-37/',\n",
    "    'LSVI': \"LSVI_uniform_beta_tracking_eval_2022-09-27/14-46-59/\",\n",
    "    'Random': 'RANDOM_LCB_uniform_beta_tracking_2022-09-24/20-20-09/',\n",
    "    'US': 'US_LCB_uniform_beta_tracking_2022-09-23/14-37-29/',\n",
    "    'Greedy': 'greedy_uniform_beta_tracking_2022-09-21/07-19-51/',\n",
    "}\n",
    "\n",
    "beta_rotation_exps = {\n",
    "    # 'GP All Data': 'beta_rotation_rbf_2022-09-13/16-34-16/',\n",
    "    # 'GP All data TQRL LCB': 'beta_rotation_rbf_lcb_tqrl_2022-09-14/09-47-07/',\n",
    "    # 'GP All Data TQRL': 'beta_rotation_rbf_tqrl_2022-09-14/22-40-14/',\n",
    "    # 'TQRL LCB visited': 'beta_rotation_rbf_lcb_tqrl_visited_2022-09-17/12-20-38/',\n",
    "    # 'TQRL (seeds)': 'TQRL_beta_rotation_2022-09-19/00-01-29/',\n",
    "    # 'US LCB': 'US_LCB_beta_rotation_2022-09-19/02-28-00/',\n",
    "    'AE-LSVI': 'TQRL_LCB_beta_rotation_2022-09-18/21-23-37/',\n",
    "    'LSVI': 'FOVI_beta_rotation_2022-09-18/19-42-28/',\n",
    "    'Random': 'RANDOM_LCB_beta_rotation_2022-09-19/02-28-46/',\n",
    "    'TQRL LCB Visited': 'TQRL_LCB_visited_beta_rotation_2022-09-19/22-58-26/',\n",
    "    'US': 'US_LCB_beta_rotation_2022-09-25/06-51-24/',\n",
    "    'Greedy': 'greedy_beta_rotation_2022-09-25/16-36-30/',\n",
    "}\n",
    "\n",
    "uniform_beta_rotation_expts = {\n",
    "    'AE-LSVI': 'TQRL_LCB_uniform_beta_rotation_2022-09-24/17-56-33/',\n",
    "    'LSVI': 'FOVI_uniform_beta_rotation_2022-09-23/06-58-33/',\n",
    "    'Random': 'RANDOM_LCB_uniform_beta_rotation_2022-09-25/23-43-55/',\n",
    "    'US': 'US_LCB_uniform_beta_rotation_2022-09-25/23-43-45/',\n",
    "    'Greedy': 'greedy_uniform_beta_rotation_2022-09-23/16-09-10/',\n",
    "}\n",
    "\n",
    "uniform_weird_gain_expts = {\n",
    "    'LSVI': 'LSVI_uniform_weird_gain_2022-09-21/19-19-40/',\n",
    "    'AE-LSVI Visited': 'AE-LSVI_LCB_visited_uniform_weird_gain_2022-09-21/19-19-30/',\n",
    "}\n",
    "\n",
    "navigation_expts = {\n",
    "    'LSVI': 'FOVI_navigation_2022-09-23/09-44-50/',\n",
    "    'AE-LSVI': 'TQRL_LCB_navigation_2022-09-23/09-48-49/',\n",
    "\n",
    "}\n",
    "    \n",
    "nav_easy_expts = {\n",
    "    'LSVI': 'FOVI_nav_easy_2022-09-23/20-03-00/seed_0/',\n",
    "    'AE-LSVI visited': 'TQRL_LCB_visited_nav_easy_2022-09-23/20-02-42',\n",
    "    'AE-LSVI': 'TQRL_LCB_nav_easy_2022-09-24/21-53-19/',\n",
    "    'US': 'US_LCB_nav_easy_2022-09-24/13-16-01/',\n",
    "    'Greedy': 'GREEDY_LCB_nav_easy_2022-09-25/06-51-09/',\n",
    "    'Random': 'RANDOM_LCB_nav_easy_2022-09-26/00-18-21/',\n",
    "}    \n",
    "\n",
    "uniform_nav_easy_expts = {\n",
    "    'AE-LSVI': 'TQRL_LCB_uniform_nav_easy_2022-09-27/05-55-00/',\n",
    "    'LSVI': 'FOVI_uniform_nav_easy_2022-09-25/17-57-38/',\n",
    "    'Greedy': 'GREEDY_LCB_uniform_nav_easy_2022-09-27/15-31-09/'\n",
    "}\n",
    "    \n",
    "gym_env_names = {\"Mountain Car\": 'densemountaincar-dt10-v0', \n",
    "                 \"Cartpole\": \"cartpoleswingup-v0\",\n",
    "                 \"Dense Cartpole\": 'cartpoleswingup-dense-v0',\n",
    "                 \"Beta Tracking\": 'betatracking-v0',\n",
    "                 \"Beta + Rotation\": 'betarotation-v0',\n",
    "                 'Navigation': \"navigation-v0\"\n",
    "                }\n",
    "horizons = {\"Mountain Car\": 25, \"Cartpole\": 25, \"Dense Cartpole\": 25, \"Uniform Dense Cartpole\": 25, \n",
    "            \"Beta Tracking\": 15, \"Uniform Beta Tracking\": 15, \"Beta + Rotation\": 20, \"Uniform Beta + Rotation\": 20,\n",
    "            \"Weird Gain\": 30,\n",
    "            \"Navigation\": 30, \"Easy Navigation\": 30, \"Uniform Easy Navigation\": 30}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ebb06d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_exps = {\"Mountain Car\": mc_exps,\n",
    "            \"Cartpole\": cp_exps,\n",
    "            \"Dense Cartpole\": dense_cp_exps,\n",
    "            \"Uniform Dense Cartpole\": uniform_dense_cp_exps,\n",
    "            \"Beta Tracking\": beta_tracking_exps,\n",
    "            \"Uniform Beta Tracking\": uniform_beta_tracking_exps,\n",
    "            \"Beta + Rotation\": beta_rotation_exps,\n",
    "            \"Uniform Beta + Rotation\": uniform_beta_rotation_expts,\n",
    "            \"Weird Gain\": uniform_weird_gain_expts,\n",
    "            \"Navigation\": navigation_expts,\n",
    "            \"Easy Navigation\": nav_easy_expts,\n",
    "            \"Uniform Easy Navigation\": uniform_nav_easy_expts,\n",
    "           }\n",
    "exps = all_exps[env_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7679d988",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_single_expt(path):\n",
    "    with path.open('rb') as f:\n",
    "        data = pickle.load(f)\n",
    "    eval_ndata = data['Eval ndata']\n",
    "    eval_returns = np.array(data['Eval Returns'])\n",
    "    mean_returns = np.mean(eval_returns, axis=1)\n",
    "    stderr_returns = np.std(eval_returns, axis=1) / np.sqrt(eval_returns.shape[1])\n",
    "    return {\"Eval ndata\": eval_ndata, \"Mean Returns\": mean_returns, \"Stderr Returns\": stderr_returns,\n",
    "            \"Exploration Returns\": data['Exploration Returns']}\n",
    "\n",
    "def process_seeds(path):\n",
    "    seed_data = []\n",
    "    nseeds = 5 if (path / f'seed_1').exists() else 1\n",
    "    for i in range(nseeds):\n",
    "        seed_path = path / f'seed_{i}' / 'info.pkl'\n",
    "        seed_data.append(process_single_expt(seed_path))\n",
    "    min_length = min([len(dat[\"Mean Returns\"]) for dat in seed_data])\n",
    "    means = np.array([dat['Mean Returns'][:min_length] for dat in seed_data])\n",
    "    seed_mean = np.mean(means, axis=0)\n",
    "    stderr = np.std(means, axis=0) / np.sqrt(means.shape[0])\n",
    "    min_length = min([len(dat[\"Exploration Returns\"]) for dat in seed_data])\n",
    "    expl_means = np.array([dat['Exploration Returns'][:min_length] for dat in seed_data])\n",
    "    expl_seed_mean = np.mean(expl_means, axis=0)\n",
    "    expl_stderr = np.std(expl_means, axis=0) / np.sqrt(expl_means.shape[0])\n",
    "    return {\"Eval ndata\": seed_data[0][\"Eval ndata\"], \"Mean Returns\": seed_mean, \"Stderr Returns\": stderr,\n",
    "            \"Exploration Returns\": expl_seed_mean, \"Expl Stderr Returns\": expl_stderr}\n",
    "\n",
    "def process_expt(path):\n",
    "    single_expt_path = path / 'info.pkl'\n",
    "    if single_expt_path.exists():\n",
    "        return process_single_expt(single_expt_path)\n",
    "    else:\n",
    "        return process_seeds(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "022a0eaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "cutoffs = {\n",
    "    \"Mountain Car\": 25, \n",
    "    \"Cartpole\": 500,\n",
    "    \"Dense Cartpole\": 500,\n",
    "    \"Uniform Dense Cartpole\": 500,\n",
    "    \"Beta Tracking\": 300,\n",
    "    \"Beta + Rotation\": 400,\n",
    "    \"Weird Gain\": 30,\n",
    "    \"Navigation\": 30,\n",
    "    \"Easy Navigation\": 850}\n",
    "# cutoff = cutoffs[env_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a288aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "expt_data = {}\n",
    "plot_cutoffs = True\n",
    "cutoff=1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07900c49",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"figure.figsize\"] = (10,6)\n",
    "fig, [ax1, ax2] = plt.subplots(1,2)\n",
    "for expt_name, expt_path in exps.items():\n",
    "    this_data = expt_data[expt_name] = process_expt(base_path / expt_path)\n",
    "    eval_ndata = this_data['Eval ndata']\n",
    "    mean_returns = this_data['Mean Returns']\n",
    "    stderr_returns = this_data['Stderr Returns']\n",
    "    # eval_eps = (np.arange(len(mean_returns)) + 1) * env.horizon\n",
    "    expl_eps = np.arange(len(this_data[\"Exploration Returns\"])) * horizons[env_name]\n",
    "    ax1.plot(eval_ndata, mean_returns, label=expt_name)\n",
    "    ax1.fill_between(eval_ndata, mean_returns - stderr_returns, mean_returns + stderr_returns, alpha=0.2)\n",
    "    ax2.plot(expl_eps, this_data[\"Exploration Returns\"], label=expt_name)\n",
    "if plot_cutoffs:\n",
    "    ax1.axvline(cutoff, color='red')\n",
    "ax1.set_title(\"Test Returns\")\n",
    "ax2.set_title(\"Exploration Returns\")\n",
    "ax1.legend()\n",
    "fig.suptitle(f\"Performance on {env_name}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36cabd3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for expt_name, data in expt_data.items():\n",
    "    ndata = np.array(data['Eval ndata'])\n",
    "    mean_returns = data['Mean Returns']\n",
    "    stderr_returns = data['Stderr Returns']\n",
    "    idx = np.argmin(np.abs(ndata - cutoff))\n",
    "    ret = mean_returns[idx]\n",
    "    std_ret = stderr_returns[idx]\n",
    "    print(f\"{expt_name}: {ret:.2f} +- {std_ret:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d4af7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_returns = np.array(data['Eval Returns'])\n",
    "mean_returns = np.mean(eval_returns, axis=1)\n",
    "std_returns = np.std(eval_returns, axis=1)\n",
    "eval_eps = (np.arange(len(mean_returns)) + 1) * env.horizon\n",
    "expl_eps = np.arange(len(mean_returns) + 2) * env.horizon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e1d01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(eval_eps, mean_returns, label=\"Eval Returns\")\n",
    "plt.fill_between(eval_eps, mean_returns - std_returns, mean_returns + std_returns, alpha=0.2)\n",
    "plt.plot(expl_eps, data[\"Exploration Returns\"], label=\"Expl Returns\")\n",
    "plt.legend()\n",
    "plt.xlabel('Number of Datapoints')\n",
    "plt.ylabel(\"Returns\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e067652a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Xtrain = data['Xtrain']\n",
    "Ytrain = data['Ytrain']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d4a36d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for t, row in enumerate(reversed(Ytrain)):\n",
    "    x = np.sort(row)\n",
    "    y = np.arange(len(x)) / float(len(x))\n",
    "    plt.plot(x, y, label=f\"t = {t}\")\n",
    "plt.legend()\n",
    "plt.ylabel(\"Cumulative Density\")\n",
    "plt.xlabel(\"Ytrain Value\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36d8aa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "X1 = Xtrain[0, ...]\n",
    "Y1 = Ytrain[0, ...]\n",
    "Y1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "157a5620",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(X1[:, 0], Y1 * 2)\n",
    "plt.xlabel(\"X position\")\n",
    "plt.ylabel(\"Reward x 2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e089b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = data['all_data']\n",
    "all_obs = all_data.next_obs\n",
    "all_obs[:, 0] - (all_data.rewards * 2.2 -  1.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a450548e",
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_by_time = all_obs.reshape((-1, horizon, all_obs.shape[1]))\n",
    "print(obs_by_time.shape)\n",
    "for ep, ep_data in enumerate(obs_by_time):\n",
    "    plt.scatter(ep_data[:, 0], ep_data[:, 1], color=cm.hot(ep / len(obs_by_time)), label=f\"Episode {ep}\")\n",
    "# plt.scatter(6, 9, s=100, color=\"green\", label=\"goal\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e72a4a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MaternGP(noise=0.05, jitter=0.3)\n",
    "trainer = GPTrainer(lr=0.01, num_iters=1, seed=0, weight_decay=0.001, constrain_gd=True, load_params=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff62f590",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_state = trainer.train(model, Xtrain[0, ...], Ytrain[0, ...], 24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a84018f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f0e0545",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_state = train_state.replace(params=train_state.params.copy({'params': {\n",
    "        'log_rho': np.array([-3, -3, -3]), \n",
    "        'log_sigma': 1.34}}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1891d4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_fn = partial(trainer._pred, train_state=train_state, Xtrain=Xtrain[0, ...], Ytrain=Ytrain[0, ...], train_diag=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82835e47",
   "metadata": {},
   "outputs": [],
   "source": [
    "z_scores = (Ytrain[0, ...] - mean[:, 0]) / np.sqrt(var[:, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90dcbf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a37043a",
   "metadata": {},
   "outputs": [],
   "source": [
    "explained_variance_score(Y1, mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f73c0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# force xdot and action to zero and see whether the GP can fit that from R->R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55ba97c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot what the GP says and whether it can fit the GT data for those values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b12cc59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# try forcing the length scales to be tiny and see if that forces an overfit to train data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc8552f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = process_expt(base_path / 'greedy_dense_cartpole_2022-09-20/20-36-44/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3decdcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7ca00f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ucb",
   "language": "python",
   "name": "ucb"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
