{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeb77f3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from sb3_contrib import ARS\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.callbacks import EvalCallback\n",
    "\n",
    "from polyagents.polynomial_policies import PolynomialARSPolicy, PolynomialPPOPolicy\n",
    "from polyagents.utils import get_normalized_vec_env\n",
    "\n",
    "from pickleshare import PickleShareDB\n",
    "db = PickleShareDB('./agent_parameters')\n",
    "\n",
    "tensorboard_log_dir = \"./tensorboard_logs/\"\n",
    "os.makedirs(tensorboard_log_dir, exist_ok=True)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "11199d9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper function\n",
    "def run_model(model, env, seed=None, episodes=1, verbose=False, render=False, options=None):\n",
    "    rewards = []\n",
    "    if seed:\n",
    "        env.seed(seed)\n",
    "    for _ in range(episodes):\n",
    "        observations = []\n",
    "        reward = 0.0\n",
    "\n",
    "        if options is not None:\n",
    "            obs = env.venv.envs[0].reset(options=options)\n",
    "        else:\n",
    "            obs = env.reset() \n",
    "        if verbose:\n",
    "            print(f'start: {env.get_original_obs()}')\n",
    "        terminated = [False]\n",
    "        truncated = [{'TimeLimit.truncated': False}]\n",
    "\n",
    "        while not (terminated[0] or truncated[0]['TimeLimit.truncated']):\n",
    "            if render:\n",
    "                env.render()\n",
    "            action, _ = model.predict(obs[0], deterministic=True)\n",
    "            obs, r, terminated, truncated = env.step([action])\n",
    "            observations.append(obs)\n",
    "            reward += r\n",
    "        rewards.append(reward)\n",
    "\n",
    "    return np.mean(rewards), np.std(rewards)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36e93ab8",
   "metadata": {},
   "source": [
    "## Chebyshev Polynomial Basis PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3646fb66",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Hyperparameters determined for CH-PPO\n",
    "learning_rate = 0.0004 # default 0.0003\n",
    "n_steps = 2048 # default = 2048\n",
    "batch_size = 64 # default = 64\n",
    "n_epochs = 5 # default = 10\n",
    "clip_range = 0.4 # default = 0.2\n",
    "clip_range_vf = 0.4 # default = None\n",
    "use_fixed_std_schedule = False # True would use fixed schedule instead of additional polynomial approximator for sigma values\n",
    "degree = 3\n",
    "params = {'learning_rate': learning_rate, 'clip_range': clip_range, 'clip_range_vf': clip_range_vf, 'n_steps': n_steps, 'n_epochs': n_epochs, 'batch_size': batch_size, 'policy_kwargs':dict(degree=degree, use_fixed_std_schedule=use_fixed_std_schedule)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "7cfe4014",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get normalized env and mode\n",
    "env, eval_env = get_normalized_vec_env(env_name='MountainCarContinuous-v0', render=False)\n",
    "model = PPO(policy=PolynomialPPOPolicy, env=env, tensorboard_log=tensorboard_log_dir, device='cpu', seed=0, **params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e055a620",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Training\n",
    "eval_callback = EvalCallback(\n",
    "    eval_env,\n",
    "    eval_freq=10_000,\n",
    "    deterministic=True,\n",
    "    render=False,\n",
    "    n_eval_episodes=1,\n",
    ")\n",
    "\n",
    "model.learn(total_timesteps=100000, tb_log_name='CH-PPO', callback=eval_callback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "fc9e7545",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "db['ch_ppo_params'] = model.policy.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e472a29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "env, eval_env = get_normalized_vec_env(env_name='MountainCarContinuous-v0', render=True)\n",
    "model = PPO(policy=PolynomialPPOPolicy, env=env, tensorboard_log=tensorboard_log_dir, device='cpu', seed=0, policy_kwargs = dict(coeffs=db['ch_ppo_params']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "7562d2c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(np.float32(96.90067), np.float32(0.0))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run model\n",
    "run_model(model, env, render=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd7ebbc6",
   "metadata": {},
   "source": [
    "## Chebyshev Polynomial Basis ARS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "d0bbcaa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_std = 0.1\n",
    "n_delta = 4 # For degree 3, increasing n_delta = n_top to greater than 4 reduces performance, while for higher degrees performance is increased (e.g. n_top=8)c\n",
    "n_top = 1 # For envs other than MountainCar, 'None' is often the right choice (leads to n_delta = n_top)\n",
    "learning_rate = 0.018\n",
    "zero_policy = False\n",
    "degree = 3\n",
    "\n",
    "params = {'learning_rate': learning_rate, 'delta_std': delta_std, 'n_delta': n_delta, 'n_top': n_top, 'zero_policy': zero_policy, 'policy_kwargs':dict(degree=degree)} "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "df6f2285",
   "metadata": {},
   "outputs": [],
   "source": [
    "env, eval_env = get_normalized_vec_env(env_name='MountainCarContinuous-v0', render=False)\n",
    "model = ARS(policy=PolynomialARSPolicy, env=env, tensorboard_log=tensorboard_log_dir, device='cpu', seed=0, **params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36076171",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Training\n",
    "eval_callback = EvalCallback(\n",
    "    eval_env,\n",
    "    eval_freq=10_000,\n",
    "    deterministic=True,\n",
    "    render=False,\n",
    "    n_eval_episodes=1,\n",
    ")\n",
    "\n",
    "model.learn(total_timesteps=150000, tb_log_name='CH-ARS', callback=eval_callback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "981bc9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "db['ch_ars_params'] = model.policy.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e95914",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "env, eval_env = get_normalized_vec_env(env_name='MountainCarContinuous-v0', render=True)\n",
    "model = ARS(policy=PolynomialARSPolicy, env=env, tensorboard_log=tensorboard_log_dir, device='cpu', seed=0, policy_kwargs = dict(coeffs=db['ch_ars_params']), zero_policy=False) # If zero_policy=True, loading of coefficients does not work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "12cc0378",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(np.float32(98.72381), np.float32(0.0))"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run model\n",
    "run_model(model, env, render=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "polynomial-sb3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
