{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a TTPI policy that handles multiple target orientations.\n",
    "# Load autoreload once, before the project imports, so edits to ttpi / pivoting_wall are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "from ttpi import TTPI\n",
    "from pivoting_wall import pwall_env\n",
    "\n",
    "# Every tensor created below without an explicit dtype is float64.\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when PyTorch can see one; otherwise run on the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# State is 4-D:  [theta, theta_dot, py, target_ori]\n",
    "# Action is 3-D: [f_1, f_2, py_dot]\n",
    "state_min = torch.tensor([0, -torch.pi, -0.02, 0]).to(device)\n",
    "state_max = torch.tensor([torch.pi/2, torch.pi, 0.02, torch.pi/2]).to(device)\n",
    "\n",
    "n_state = 50   # grid points per state dimension\n",
    "n_action = 50  # grid points per action dimension\n",
    "mass = 0.16\n",
    "g = 9.81\n",
    "u_ps = 0.3  # NOTE(review): presumably a friction coefficient; it scales the f_2 bounds below -- confirm\n",
    "\n",
    "# Upper bound on f_1: four times the object's weight (mass * g).\n",
    "max_normal_force = 4 * mass * g\n",
    "\n",
    "# f_1 in [0, max_normal_force], f_2 in [-u_ps, u_ps] * max_normal_force, py_dot in [0, 0.05].\n",
    "action_min = torch.tensor([0, -u_ps*max_normal_force, 0])\n",
    "action_max = torch.tensor([max_normal_force, u_ps*max_normal_force, 0.05])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation time step; forward_model (a later cell) passes the same dt to forward_simulate.\n",
    "dt = 0.001\n",
    "dyn_system = pwall_env(state_min, state_max, device=device, dt=dt)\n",
    "\n",
    "# NOTE(review): unused leftover from an earlier formulation, kept for reference only.\n",
    "# p_target = torch.tensor([dyn_system.l2-dyn_system.l1, dyn_system.l1-dyn_system.l2, 3*torch.pi/8, 0, 0,0,0, 0]).view(1,-1).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One uniform grid per dimension, spanning that dimension's [min, max] bound.\n",
    "domain_state = [torch.linspace(lo, hi, n_state).to(device) for lo, hi in zip(state_min, state_max)]\n",
    "domain_action = [torch.linspace(lo, hi, n_action).to(device) for lo, hi in zip(action_min, action_max)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test initial states: theta, theta_dot and py start at 0; target_ori is drawn uniformly\n",
    "# from [state_min[-1], state_max[-1]]. A seeded local generator keeps the test set\n",
    "# reproducible without reseeding the global RNG used during training.\n",
    "# (To randomize the other coordinates too, sample them inside [state_min, state_max] the same way.)\n",
    "n_test = 100\n",
    "test_seed = 0\n",
    "dim_state = len(domain_state)\n",
    "gen = torch.Generator().manual_seed(test_seed)\n",
    "# torch.zeros, not torch.empty: columns that are not sampled must hold defined values.\n",
    "init_state = torch.zeros((n_test, dim_state), device=device)\n",
    "u = torch.rand(n_test, generator=gen).to(device)\n",
    "init_state[:, -1] = state_min[-1] + u * (state_max[-1] - state_min[-1])\n",
    "state = init_state.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_model(state, action):\n",
    "    \"\"\"Next state from dyn_system.forward_simulate for (state, action) with the global time step dt.\"\"\"\n",
    "    return dyn_system.forward_simulate(state, action, dt)\n",
    "\n",
    "\n",
    "def reward(state, action):\n",
    "    \"\"\"Reward for (state, action) as computed by dyn_system.reward_state_action.\"\"\"\n",
    "    return dyn_system.reward_state_action(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tol = torch.tensor([0.03, 0.03, 15/180*torch.pi]).to(device)[:3]\n",
    "\n",
    "def callback(ttdp, state=state, file_name='fig', callback_count=0, horizon=1000):\n",
    "    \"\"\"Roll out the current policy from `state`, report policy latency, and animate one trajectory.\n",
    "\n",
    "    Args:\n",
    "        ttdp: object providing `policy(state)` and `reward_normalized(state, action)`.\n",
    "        state: batch of initial states (rows are samples). NOTE: the default is bound to the\n",
    "            global `state` when this cell runs; re-run the cell after regenerating test states.\n",
    "        file_name, callback_count: accepted for the caller's interface; not used here.\n",
    "        horizon: number of simulation steps to roll out (default 1000).\n",
    "\n",
    "    Returns:\n",
    "        (mean normalized reward at the last step, mean cumulative normalized reward), on CPU.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if horizon < 1 (there would be no last-step reward to report).\n",
    "    \"\"\"\n",
    "    if horizon < 1:\n",
    "        raise ValueError(\"horizon must be >= 1\")\n",
    "    print(\"Testing....\")\n",
    "\n",
    "    n = state.shape[0]\n",
    "    # Preallocate the rollout buffer instead of torch.concat-ing one step at a time (quadratic copying).\n",
    "    traj = torch.empty((n, horizon + 1, state.shape[-1]), dtype=state.dtype, device=state.device)\n",
    "    traj[:, 0] = state\n",
    "    traj_actions = None  # allocated once the action dimension is known\n",
    "    cum_reward = torch.zeros(n, dtype=state.dtype, device=state.device)\n",
    "    sync_cuda = state.is_cuda\n",
    "    policy_time = 0.0\n",
    "    for i in range(horizon):\n",
    "        # CUDA runs asynchronously: synchronize so the timer covers the policy's actual compute.\n",
    "        if sync_cuda:\n",
    "            torch.cuda.synchronize()\n",
    "        t0 = time.perf_counter()\n",
    "        action = ttdp.policy(state)\n",
    "        if sync_cuda:\n",
    "            torch.cuda.synchronize()\n",
    "        policy_time += time.perf_counter() - t0\n",
    "\n",
    "        r = ttdp.reward_normalized(state, action)\n",
    "        cum_reward += r\n",
    "        state = forward_model(state, action)\n",
    "\n",
    "        if traj_actions is None:\n",
    "            traj_actions = torch.empty((n, horizon, action.shape[-1]), dtype=state.dtype, device=state.device)\n",
    "        traj[:, i + 1] = state\n",
    "        traj_actions[:, i] = action\n",
    "    print(\"time taken by policy: \", policy_time / horizon)\n",
    "\n",
    "    plot_num = 1  # number of trajectories (from the end of the batch) to animate\n",
    "    plt = dyn_system.animate_pivot(traj[-plot_num:].cpu().numpy(), traj_actions[-plot_num:].cpu().numpy(),\n",
    "                                   animation=True, step_skip=10, xmax=0.3, figsize=5, scale=10)\n",
    "\n",
    "    plt.show()\n",
    "    return r.mean().to(\"cpu\"), cum_reward.mean().to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the TTPI solver on the discretized state/action grids built above.\n",
    "# NOTE(review): the _v / _a suffixes presumably configure two separate TT approximations\n",
    "# (value vs. action side) -- confirm against the TTPI constructor.\n",
    "ttpi = TTPI(\n",
    "    domain_state=domain_state,\n",
    "    domain_action=domain_action,\n",
    "    reward=reward,\n",
    "    normalize_reward=True,\n",
    "    forward_model=forward_model,\n",
    "    gamma=0.99,\n",
    "    rmax_v=100,\n",
    "    rmax_a=100,\n",
    "    nswp_v=5,\n",
    "    nswp_a=5,\n",
    "    kickrank_v=10,\n",
    "    kickrank_a=10,\n",
    "    max_batch_v=10**4,\n",
    "    max_batch_a=10**5,\n",
    "    eps_cross_v=1e-3,\n",
    "    eps_cross_a=1e-3,\n",
    "    eps_round_v=1e-3,\n",
    "    eps_round_a=1e-3,\n",
    "    n_samples=50,\n",
    "    verbose=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): resume=True presumably continues a previously saved 'pivoting' run -- confirm in TTPI.train.\n",
    "resume = False\n",
    "ttpi.train(\n",
    "    resume=resume,\n",
    "    callback=callback,\n",
    "    callback_freq=10,\n",
    "    verbose=False,\n",
    "    file_name='pivoting',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "06d9862aae01c2e6b250b7ba50e856fe4781a4d8639925b199bc13e1b7bc5401"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
