{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import numpy as np\n",
    "from ttpi import TTPI\n",
    "from dynamic_systems import SinglePendulum\n",
    "torch.set_default_dtype(torch.float64)\n",
    "from plot_utils import plt_pendulum\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time step and state box (theta, dtheta); the minimum is the mirrored maximum.\n",
    "dt = 0.01\n",
    "state_max = torch.tensor([torch.pi, 4*torch.pi]).to(device)  # (theta, dtheta)\n",
    "state_min = -1*state_max\n",
    "\n",
    "# Discretization sizes\n",
    "n_state = [100]*2  # grid points per state dimension\n",
    "n_action = 200     # number of discrete torque values\n",
    "\n",
    "# Pendulum parameters\n",
    "mass = 1.0\n",
    "length = 1.0\n",
    "g = 9.81\n",
    "coef_viscous = 0.01\n",
    "\n",
    "# Symmetric torque limit: 25% of m*g*l\n",
    "action_max = 0.25*mass*g*length\n",
    "action_min = -1*action_max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One 1-D grid per state dimension: [min, 0] followed by (0, max], so 0 appears once.\n",
    "domain_state = [\n",
    "    torch.concat(\n",
    "        (\n",
    "            torch.linspace(lo, 0, int(n/2)).to(device),\n",
    "            torch.linspace(0, hi, int(n/2)).to(device)[1:],\n",
    "        ),\n",
    "        dim=-1,\n",
    "    )\n",
    "    for lo, hi, n in zip(state_min, state_max, n_state)\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tt_utils import get_exponential_discretization\n",
    "\n",
    "# Torque grid over [action_min, action_max], sorted ascending with a single 0 node,\n",
    "# built the same way as domain_state: [action_min, 0] followed by (0, action_max].\n",
    "n_action_half = int(n_action/2)\n",
    "domain_action_neg = torch.linspace(action_min, 0, n_action_half).to(device)\n",
    "domain_action_pos = torch.linspace(0, action_max, n_action_half).to(device)[1:]\n",
    "domain_action = [torch.concat((domain_action_neg, domain_action_pos), dim=-1)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights and torque limits passed to the SinglePendulum model used for training.\n",
    "w_goal = 1.0\n",
    "w_action = 0.1\n",
    "w_scale = 1\n",
    "a_max = torch.tensor([action_max]).to(device).view(-1, 1)  # shape (1, 1)\n",
    "a_min = -1*a_max\n",
    "\n",
    "# NOTE(review): this rebinds `sys`, shadowing the `import sys` module from the first cell.\n",
    "sys = SinglePendulum(\n",
    "    mass=mass, length=length, coef_viscous=coef_viscous,\n",
    "    state_min=state_min, state_max=state_max,\n",
    "    action_min=a_min, action_max=a_max,\n",
    "    dt=dt, w_scale=w_scale, w_goal=w_goal, w_action=w_action,\n",
    "    device=device,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_model(state, action):\n",
    "    \"\"\"Dynamics step used by TTPI: delegates to sys.forward_simulate(state, action).\"\"\"\n",
    "    next_state = sys.forward_simulate(state, action)\n",
    "    return next_state\n",
    "\n",
    "\n",
    "def reward(state, action):\n",
    "    \"\"\"Reward used by TTPI: delegates to sys.reward_state_action(state, action).\"\"\"\n",
    "    return sys.reward_state_action(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): sys_test is not referenced below (callback defaults to sys_t=sys); pass\n",
    "# sys_t=sys_test to roll out on it instead. Its scale factors (1.*mass, 1*coef_viscous,\n",
    "# 1.*length) presumably exist for model-mismatch tests - confirm. Unlike `sys`, it is given\n",
    "# float torque limits (action_max/action_min) rather than the (1, 1) tensors a_max/a_min.\n",
    "sys_test = SinglePendulum(dt=dt, mass=1.*mass,coef_viscous=1*coef_viscous, length=1.*length, state_min=state_min, state_max=state_max, action_max=action_max, action_min=action_min, w_scale=w_scale, w_goal=w_goal, w_action=w_action, device=device)\n",
    "\n",
    "\n",
    "def callback(ttdp, T=10, sys_t=sys, animation=False, file_name='file', callback_count=0):\n",
    "    \"\"\"Roll out the current policy from fixed initial states, then plot and score it.\n",
    "\n",
    "    Args:\n",
    "        ttdp: solver exposing ``policy(state)`` and ``reward_normalized(state, action)``.\n",
    "        T: rollout horizon in seconds, converted to ``round(T / sys_t.dt)`` steps.\n",
    "        sys_t: system used for the rollout (its ``dt`` and ``forward_simulate``).\n",
    "            The default is bound to ``sys`` when this cell runs.\n",
    "        animation: forwarded to ``plt_pendulum``.\n",
    "        file_name, callback_count: unused here; presumably kept because\n",
    "            ``ttpi.train`` passes them - confirm.\n",
    "\n",
    "    Returns:\n",
    "        Tuple on CPU: (mean over the initial states of the last-step reward,\n",
    "        mean cumulative reward over the rollout).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``T`` is shorter than one time step.\n",
    "    \"\"\"\n",
    "    from matplotlib import pyplot as plt0\n",
    "\n",
    "    print(\"Testing....\")\n",
    "    state = torch.tensor([[0.,0.],\n",
    "                          [0,0.1],\n",
    "                          [0,0.2],\n",
    "                          [0.1,-0.1],\n",
    "                          [-0.5,0],\n",
    "                          [0.2,0.2],\n",
    "                          [-0.2,0.3]]).to(device).view(-1,2)\n",
    "    batch_size = state.shape[0]\n",
    "\n",
    "    dt = sys_t.dt\n",
    "    # round() before int(): T/dt can fall just below an integer (e.g. 0.3/0.1 -> 2.999...).\n",
    "    n_steps = int(round(T/dt))\n",
    "    if n_steps < 1:\n",
    "        # Without at least one step, the last-step reward `r` would be undefined.\n",
    "        raise ValueError(f\"T={T} is shorter than one time step (dt={dt})\")\n",
    "\n",
    "    traj = torch.zeros(batch_size, n_steps, 2)  # state after each step, kept on CPU\n",
    "    cum_reward = torch.zeros(batch_size).to(device)\n",
    "    for i in range(n_steps):\n",
    "        action = ttdp.policy(state)  # batch_size x 1\n",
    "        r = ttdp.reward_normalized(state, action)\n",
    "        cum_reward += r\n",
    "        state = sys_t.forward_simulate(state, action)\n",
    "        traj[:, i, :] = state.clone().cpu()\n",
    "    total_reward = torch.mean(cum_reward)\n",
    "    print(\"Avg cumulative reward: \", total_reward)\n",
    "\n",
    "    # Only the rollout from the first initial state ([0, 0]) is plotted.\n",
    "    theta_t = traj[0, :, 0]\n",
    "    dtheta_t = traj[0, :, 1]\n",
    "\n",
    "    fig, ax = plt0.subplots()\n",
    "    ax.plot(np.arange(n_steps)*dt, dtheta_t)\n",
    "    ax.grid()\n",
    "    ax.set_ylabel(r'Joint velocity, rad/s')\n",
    "    ax.set_xlabel('Time in seconds', fontsize=10)\n",
    "    ax.set_title('velocity')\n",
    "    plt0.show()\n",
    "\n",
    "    anim_plt = plt_pendulum(theta_t.numpy(), figsize=5, dt=dt, scale=10, skip=50, animation=animation)\n",
    "    anim_plt.show()\n",
    "    return r.mean().to('cpu'), total_reward.to('cpu')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
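   "source": [
    "## Train the TTPI controller\n",
    "\n",
    "Build the tensor-train policy-iteration solver on the discretized state and action domains, then train it; `callback` rolls out and plots the current policy during training."
   ]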
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tensor-train policy iteration over the discretized state/action domains.\n",
    "# NOTE(review): TTPI may offer policy modes such as 'deterministic_tt', 'stochastic_tt' and\n",
    "# 'random', but none of the arguments below selects one - check TTPI's signature for it.\n",
    "ttpi = TTPI(domain_state=domain_state, domain_action=domain_action, reward=reward,\n",
    "            forward_model=forward_model, gamma=0.999, n_steps=1,\n",
    "            rmax_v=100, rmax_a=100, nswp_v=10, nswp_a=10,\n",
    "            kickrank_v=10, kickrank_a=10,\n",
    "            max_batch_v=10**4, max_batch_a=10**5,\n",
    "            eps_cross_v=1e-3, eps_cross_a=1e-3,\n",
    "            eps_round_v=1e-3, eps_round_a=1e-3,\n",
    "            n_samples=1, normalize_reward=False,\n",
    "            verbose=True,\n",
    "            device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resume is forwarded to ttpi.train; presumably True continues a saved 'swingup' run - confirm.\n",
    "resume = False\n",
    "ttpi.train(\n",
    "    n_iter_max=1000,\n",
    "    resume=resume,\n",
    "    callback=callback,\n",
    "    callback_freq=25,\n",
    "    verbose=False,\n",
    "    file_name='swingup',\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "cf96f6c213ba3f9333b362e3bb271376c1f8feeec3b85b92580d68346ee16de3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
