{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "sys.path.append('../')  # make the project-local modules below importable\n",
    "from ttpi import TTPI\n",
    "from pushing_dyn_explicit_double import pusher_slider_sys\n",
    "\n",
    "# Tensors created without an explicit dtype default to float64 from here on.\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target slider pose, shaped (1, 3) so it broadcasts over a batch\n",
    "# (presumably (s_x, s_y, s_theta), matching the first three state entries).\n",
    "p_target = torch.zeros(1, 3, device=device)\n",
    "dt = 0.025  # time step handed to pusher_slider_sys\n",
    "\n",
    "# Two independent instances built with identical settings.\n",
    "dyn_system = pusher_slider_sys(p_target=p_target, dt=dt, device=device)\n",
    "dyn_system_test = pusher_slider_sys(p_target=p_target, dt=dt, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Continuous state: (s_x, s_y, s_theta, p_x, p_y).\n",
    "# Full state appends the discrete face_id in {0, 1, 2, 3} (6 dims in total).\n",
    "w = 0.5  # half-width of the s_x / s_y range; also used as the reward/plot scale below\n",
    "state_max_c = torch.tensor([w, w, torch.pi, -0.065, 0.06]).to(device)\n",
    "state_min_c = torch.tensor([-w, -w, -torch.pi, -0.1, -0.06]).to(device)\n",
    "\n",
    "# (s_x, s_y, s_theta, p_x, p_y, face_id): derived from the continuous bounds so the two never drift apart.\n",
    "state_max = torch.cat((state_max_c, torch.tensor([3.]).to(device)))\n",
    "state_min = torch.cat((state_min_c, torch.tensor([0.]).to(device)))\n",
    "is_state_c = torch.tensor([1]*len(state_max)).to(device)  # 1 marks a continuous dim, 0 a discrete one\n",
    "is_state_c[-1] = 0  # face_id is discrete\n",
    "\n",
    "n = 100                          # grid points per continuous state dim\n",
    "n_state = [n]*len(state_max_c)\n",
    "n_state[2] = 100                 # theta resolution, kept as its own knob\n",
    "n_action = 50                    # grid points per continuous action dim\n",
    "\n",
    "v = 0.1\n",
    "action_max_c = torch.tensor([v, v]).to(device)  # (p_ddx, p_ddy)\n",
    "action_min_c = torch.tensor([0, -v]).to(device)\n",
    "# Full action bounds, ordered like domain_action: (face_id, p_ddx, p_ddy).\n",
    "action_max = torch.tensor([3, v, v]).to(device)\n",
    "action_min = torch.tensor([0, 0, -v]).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_grid(lo, hi, n_pts):\n",
    "    \"\"\"Return a strictly increasing 1-D grid on [lo, hi] with 2*int(n_pts/2) - 1 points.\n",
    "\n",
    "    If the interval straddles 0, the grid is built from two halves joined at 0, so 0 is\n",
    "    an exact grid point. Otherwise a plain linspace is used. Splitting at 0 for an\n",
    "    interval that does not contain it (e.g. p_x in [-0.1, -0.065]) would give an\n",
    "    unsorted grid with points outside [lo, hi].\n",
    "    \"\"\"\n",
    "    half = int(n_pts/2)\n",
    "    if lo < 0 < hi:\n",
    "        neg = torch.linspace(lo, 0, half)\n",
    "        pos = torch.linspace(0, hi, half)[1:]  # drop the duplicated 0\n",
    "        return torch.concat((neg, pos), dim=-1)\n",
    "    return torch.linspace(lo, hi, 2*half - 1)\n",
    "\n",
    "\n",
    "domain_state_c = [make_grid(float(lo), float(hi), n_i).to(device)\n",
    "                  for lo, hi, n_i in zip(state_min_c, state_max_c, n_state)]\n",
    "assert all(bool((torch.diff(g) > 0).all()) for g in domain_state_c), \"state grids must be strictly increasing\"\n",
    "\n",
    "face_ids = torch.tensor([0., 1., 2., 3.]).to(device)  # values of the discrete face_id\n",
    "domain_state = domain_state_c + [face_ids]\n",
    "domain_action = [face_ids.clone()] + [torch.linspace(action_min_c[i], action_max_c[i], n_action).to(device)\n",
    "                                      for i in range(len(action_max_c))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_model(state, action):\n",
    "    \"\"\"One-step transition: delegates to `dyn_system.dynamics`.\"\"\"\n",
    "    return dyn_system.dynamics(state, action)\n",
    "\n",
    "\n",
    "def reward(state, action):\n",
    "    \"\"\"Reward is the negated `dyn_system.cost_func`, called with `scale=w`.\"\"\"\n",
    "    return -dyn_system.cost_func(state, action, scale=w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)  # fixed test set, so success rates are comparable across runs\n",
    "n_test = 100\n",
    "dim_state = len(domain_state)\n",
    "\n",
    "# Continuous dims: uniform over the central half [25%, 75%] of each range.\n",
    "# (torch.rand(n).clip(0.25, 0.75) put about half of the samples exactly on the band edges.)\n",
    "u = 0.25 + 0.5*torch.rand(n_test, dim_state, device=device)\n",
    "init_state = state_min + u*(state_max - state_min)\n",
    "\n",
    "# Discrete dims (face_id): draw uniformly from the grid values instead of a fractional id.\n",
    "for i in range(dim_state):\n",
    "    if not is_state_c[i]:\n",
    "        grid = domain_state[i]\n",
    "        init_state[:, i] = grid[torch.randint(len(grid), (n_test,), device=device)]\n",
    "\n",
    "state = init_state.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Success tolerance on the final pose error (s_x, s_y, s_theta): 0.03, 0.03, 15 degrees.\n",
    "tol = torch.tensor([0.03, 0.03, 15/180*torch.pi]).to(device)\n",
    "\n",
    "\n",
    "def callback(ttpi, state=state, file_name='fig', callback_count=0):\n",
    "    \"\"\"Roll out the current TTPI policy from `state`, print statistics and plot the rollouts.\n",
    "\n",
    "    Passed to `ttpi.train` as `callback`. `file_name` and `callback_count` are accepted\n",
    "    for interface compatibility but are not used here.\n",
    "\n",
    "    Returns (mean reward of the final step, mean cumulative reward), both on CPU.\n",
    "    \"\"\"\n",
    "    from plot_utils import plot_planarpush  # project-local plotting helper\n",
    "\n",
    "    print(\"Testing....\")\n",
    "    T = 2000\n",
    "    n_batch = state.shape[0]\n",
    "    # Preallocate instead of torch.concat-ing every step, which copies O(T^2) data.\n",
    "    traj = torch.empty(n_batch, T + 1, state.shape[1], dtype=state.dtype, device=state.device)  # bs x (T+1) x 6\n",
    "    traj[:, 0] = state\n",
    "    traj_actions = torch.empty(n_batch, T, 3).to(device)  # bs x T x 3\n",
    "    cum_reward = torch.zeros(n_batch).to(device)\n",
    "    dt_cum = 0.\n",
    "    for i in range(T):\n",
    "        t0 = time.time()\n",
    "        action = ttpi.policy(state)\n",
    "        if state.is_cuda:\n",
    "            torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait so the timing is real\n",
    "        dt_cum += time.time() - t0\n",
    "        r = ttpi.reward_normalized(state, action)\n",
    "        cum_reward += r\n",
    "        state = forward_model(state, action)\n",
    "        traj[:, i + 1] = state\n",
    "        traj_actions[:, i, :] = action\n",
    "    print(\"time taken by policy: \", dt_cum/T)\n",
    "\n",
    "    # A rollout succeeds if its final (s_x, s_y, s_theta) is within `tol` of the target.\n",
    "    final_err = torch.abs(state[:, :3] - p_target)\n",
    "    succ_rate = torch.all(final_err <= tol, dim=1).double().mean().item()\n",
    "    print(f\"Success rate of {n_batch} tests is {succ_rate*100}%\")\n",
    "\n",
    "    plot_num = 99  # number of plotted tasks (the last `plot_num` rollouts)\n",
    "    plt = plot_planarpush(traj[-plot_num:].to('cpu').numpy(),\n",
    "                          traj_actions[-plot_num:].to('cpu').numpy(),\n",
    "                          animation=False, step_skip=40,\n",
    "                          xmax=w, x_target=p_target[0].to('cpu').numpy(), figsize=5,\n",
    "                          save_as=None,\n",
    "                          scale=10)\n",
    "    plt.show()\n",
    "    return r.mean().to(\"cpu\"), cum_reward.mean().to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TTPI takes two parallel hyper-parameter sets, suffixed `_v` and `_a`\n",
    "# (presumably for its value and action/advantage approximations -- confirm in TTPI).\n",
    "ttpi = TTPI(\n",
    "    domain_state=domain_state,\n",
    "    domain_action=domain_action,\n",
    "    reward=reward,\n",
    "    forward_model=forward_model,\n",
    "    normalize_reward=False,\n",
    "    gamma=0.99,\n",
    "    rmax_v=100, nswp_v=5, kickrank_v=10, max_batch_v=10**4, eps_cross_v=1e-3, eps_round_v=1e-4,\n",
    "    rmax_a=100, nswp_a=5, kickrank_a=20, max_batch_a=10**5, eps_cross_a=1e-3, eps_round_a=1e-3,\n",
    "    n_samples=10,\n",
    "    verbose=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the policy; `callback` (defined above) is handed to train together with callback_freq.\n",
    "ttpi.train(n_iter_max=100, n_iter_v=1,\n",
    "           callback=callback, callback_freq=10,\n",
    "           verbose=False, file_name='pushing')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** summarize the training results here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "cf96f6c213ba3f9333b362e3bb271376c1f8feeec3b85b92580d68346ee16de3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
