{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import custom_gym\n",
    "import matplotlib\n",
    "# matplotlib.use('Agg')\n",
    "%matplotlib inline\n",
    "\n",
    "from  dynamic_systems_rl import SinglePendulum\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "# NOTE(review): `device` is not referenced by any other cell in this notebook.\n",
    "device_name = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = torch.device(device_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integration / environment time step (passed to CreateGymEnv and forward_simulate).\n",
    "dt = 0.01\n",
    "\n",
    "# Observation bounds: one entry per state component (the state has 3 components).\n",
    "# NOTE(review): the testing cell decodes the angle as arctan2(s[0], s[1]), which implies the\n",
    "# layout (sin theta, cos theta, dtheta) -- confirm against SinglePendulum.\n",
    "state_max = torch.tensor([1.0, 1.0, 10.0])\n",
    "state_min = -state_max\n",
    "\n",
    "# Physical parameters of the pendulum.\n",
    "mass = 1.0\n",
    "length = 1.0\n",
    "g = 9.81\n",
    "coef_viscous = 0.001\n",
    "\n",
    "# Torque limit 0.5 * m * g * l (presumably half the peak gravitational torque -- confirm).\n",
    "action_max = torch.tensor([0.5 * mass * g * length])\n",
    "\n",
    "# Weights forwarded to SinglePendulum (presumably goal vs. action-cost reward terms).\n",
    "w_goal = 1\n",
    "w_action = 1\n",
    "\n",
    "agent = SinglePendulum(\n",
    "    state_min=state_min,\n",
    "    state_max=state_max,\n",
    "    mass=mass,\n",
    "    length=length,\n",
    "    coef_viscous=coef_viscous,\n",
    "    w_goal=w_goal,\n",
    "    w_action=w_action,\n",
    "    action_max=action_max,\n",
    "    use_gym=False,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-component reset bound, same 3-component layout as state_max above.\n",
    "# NOTE(review): how CreateGymEnv samples initial states from this bound is not visible here.\n",
    "reset_bound = np.array([1.0, 1.0, 0.1])\n",
    "env = custom_gym.CreateGymEnv(agent, reset_bound=reset_bound, dt=dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the custom environment against the API Stable-Baselines3 expects;\n",
    "# the checker prints additional warnings when something looks off.\n",
    "from stable_baselines3.common.env_checker import check_env\n",
    "\n",
    "check_env(env=env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO, SAC, DDPG\n",
    "\n",
    "# Instantiate the SAC learner with an MLP policy (net_arch=[64, 64]) and\n",
    "# generalized state-dependent exploration whose noise is resampled every 4 steps.\n",
    "policy_kwargs = {\"net_arch\": [64, 64]}\n",
    "sac_kwargs = {\n",
    "    \"gamma\": 0.99,\n",
    "    \"learning_rate\": 1e-3,\n",
    "    \"use_sde\": True,\n",
    "    \"sde_sample_freq\": 4,\n",
    "    \"policy_kwargs\": policy_kwargs,\n",
    "    \"verbose\": 1,\n",
    "}\n",
    "model = SAC(\"MlpPolicy\", env, **sac_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy(state):\n",
    "    \"\"\"Return the trained SAC `model`'s deterministic action for an observation batch.\n",
    "\n",
    "    Accepts a torch tensor or a NumPy array. Tensors are detached and moved to the CPU\n",
    "    first so `model.predict` always receives a NumPy array, even if the simulator\n",
    "    produced a GPU tensor. The second value returned by `model.predict` is discarded.\n",
    "    \"\"\"\n",
    "    if torch.is_tensor(state):\n",
    "        state = state.detach().cpu().numpy()\n",
    "    action, _ = model.predict(state, deterministic=True)\n",
    "    return action\n",
    "\n",
    "\n",
    "def callback(agent, policy, dt, animation=True, file_name='pendu', horizon=10.0):\n",
    "    \"\"\"Roll out `policy` on `agent` from a fixed initial state, plot it, and report rewards.\n",
    "\n",
    "    Args:\n",
    "        agent: pendulum model exposing `reward_state_action(state, action)` and\n",
    "            `forward_simulate(state, action, dt)`.\n",
    "        policy: callable mapping a (1, 3) state tensor to a NumPy action array.\n",
    "        dt: step size passed to `agent.forward_simulate` and to the plot.\n",
    "        animation: forwarded to `plt_pendulum`.\n",
    "        file_name: not used by this function (kept for backward compatibility).\n",
    "        horizon: rollout length in the units of `dt`; the rollout runs\n",
    "            int(horizon / dt) steps. Defaults to the previously hard-coded 10.0.\n",
    "\n",
    "    Returns:\n",
    "        (batch-mean reward of the final step, batch-mean cumulative reward), as CPU tensors.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `horizon / dt` is shorter than one step.\n",
    "    \"\"\"\n",
    "    # Import up front so a missing plotting helper fails before the rollout, not after it.\n",
    "    from plot_utils import plt_pendulum\n",
    "\n",
    "    print(\"Testing....\")\n",
    "    n_steps = int(horizon / dt)\n",
    "    if n_steps < 1:\n",
    "        raise ValueError(\"horizon must cover at least one time step of size dt\")\n",
    "\n",
    "    # NOTE(review): the arctan2 decoding below implies the layout (sin theta, cos theta, dtheta),\n",
    "    # making [0, 1, 0] theta = 0 with zero velocity -- confirm against SinglePendulum.\n",
    "    state = torch.tensor([[0., 1., 0.]])\n",
    "    traj = torch.zeros(n_steps + 1, 3)\n",
    "    # Row 0 must hold the initial state; otherwise it stays all-zero and the first\n",
    "    # plotted angle becomes arctan2(0, 0) = 0 regardless of the true starting angle.\n",
    "    traj[0, :] = state.clone().cpu()\n",
    "    cum_reward = torch.zeros(state.shape[0])\n",
    "    for i in range(n_steps):\n",
    "        action = torch.from_numpy(policy(state))  # batch_size x 1\n",
    "        r = agent.reward_state_action(state, action)\n",
    "        cum_reward += r\n",
    "        state = agent.forward_simulate(state, action, dt)\n",
    "        traj[i + 1, :] = state.clone().cpu()\n",
    "\n",
    "    theta_t = torch.arctan2(traj[:, 0], traj[:, 1])\n",
    "    plot = plt_pendulum(theta_t.numpy(), figsize=5, dt=dt, scale=10, skip=1,\n",
    "                        animation=animation)\n",
    "    plot.show()\n",
    "    total_reward = torch.mean(cum_reward)\n",
    "    return r.mean().to('cpu'), total_reward.to('cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Train in chunks; after each chunk, save the model and evaluate it with `callback`.\n",
    "N_CHUNKS = 10\n",
    "STEPS_PER_CHUNK = int(1e5)\n",
    "\n",
    "T = 0  # cumulative wall-clock training time in seconds\n",
    "for i in range(N_CHUNKS):\n",
    "    t1 = time.perf_counter()\n",
    "    # reset_num_timesteps=False keeps the model's timestep counter running across chunks.\n",
    "    # With the default (True) every chunk restarts at 0, which re-runs SAC's learning_starts\n",
    "    # warm-up and restarts the logged timestep axis on each call.\n",
    "    model.learn(total_timesteps=STEPS_PER_CHUNK, reset_num_timesteps=False)\n",
    "    t2 = time.perf_counter()\n",
    "    T += t2 - t1\n",
    "    print(\"Time Taken: \", t2 - t1)\n",
    "    print(\"Total time: \", T)\n",
    "    model.save('pendulum')\n",
    "    r_m, r_cum = callback(agent=agent, policy=policy, dt=dt)\n",
    "    print(\"mean final-step reward, mean cumulative reward: \", r_m, r_cum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
