{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from ttpi import TTPI\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=2)\n",
    "torch.set_printoptions(precision=2)\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "from dynamic_systems import PointMass\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when one is available; tensors below are placed on this device.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Control mode of the point mass:\n",
    "#   1 -> 'velocity control'     (state = position, action = velocity)\n",
    "#   2 -> 'acceleration control' (state = position + velocity, action = acceleration)\n",
    "order = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Workspace, discretization and limits ---\n",
    "dim = 2  # 2D point-mass\n",
    "L = 1.0  # positions lie in the box [-L, L]^dim\n",
    "position_max = torch.tensor([L]*dim).to(device)\n",
    "position_min = -1*position_max\n",
    "\n",
    "n_state = 50    # grid points per state dimension\n",
    "n_action = 100  # grid points per action dimension\n",
    "\n",
    "dt = 0.01  # time step\n",
    "\n",
    "velocity_max = position_max/4\n",
    "velocity_min = -1*velocity_max\n",
    "\n",
    "acceleration_max = 1*velocity_max\n",
    "acceleration_min = -1*acceleration_max\n",
    "\n",
    "# order 1: state = position,              action = velocity\n",
    "# order 2: state = (position, velocity),  action = acceleration\n",
    "if order == 1:\n",
    "    state_min, state_max = position_min, position_max\n",
    "    action_min, action_max = velocity_min, velocity_max\n",
    "elif order == 2:\n",
    "    state_min = torch.cat((position_min, velocity_min), dim=-1)\n",
    "    state_max = torch.cat((position_max, velocity_max), dim=-1)\n",
    "    action_min, action_max = acceleration_min, acceleration_max\n",
    "else:\n",
    "    raise ValueError(f\"order must be 1 or 2, got {order!r}\")\n",
    "\n",
    "# One uniform grid per state dimension (x for order 1; x, dx for order 2).\n",
    "domain_state = [torch.linspace(state_min[i], state_max[i], n_state).to(device)\n",
    "                for i in range(len(state_max))]\n",
    "\n",
    "# One uniform grid per action dimension, with n_action points each.\n",
    "domain_action = [torch.linspace(action_min[i], action_max[i], n_action).to(device)\n",
    "                 for i in range(len(action_max))]\n",
    "\n",
    "# Obstacles: centres `x_obst` and radii `r_obst` (one radius per centre).\n",
    "# Alternative layouts (commented out):\n",
    "#   x_obst = [torch.tensor([0.,-0.5]).to(device)]\n",
    "#   x_obst = [torch.tensor([0.35,0.5]).to(device), torch.tensor([0.35,-0.3]).to(device),\n",
    "#             torch.tensor([-0.1,-0.55]).to(device)]\n",
    "#   r_obst = [0.2]*len(x_obst)\n",
    "x_obst = [torch.tensor([0.,-0.4]).to(device)]\n",
    "r_obst = [0.2]\n",
    "assert len(x_obst) == len(r_obst), \"need exactly one radius per obstacle centre\"\n",
    "\n",
    "# NOTE: this rebinds the name `sys`, shadowing the stdlib module imported in the\n",
    "# first cell; forward_model / reward below reach the PointMass instance through it.\n",
    "sys = PointMass(order=order, dt=dt, dim=dim,\n",
    "                x_obst=x_obst, r_obst=r_obst,\n",
    "                w_obst=1e2, w_action=5e2, w_goal=1e2, w_scale=1, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_model(state, action):\n",
    "    \"\"\"Return the state reached from `state` after applying `action`.\n",
    "\n",
    "    Thin wrapper around `sys.forward_simulate`, where `sys` is the PointMass\n",
    "    instance built in the setup cell (not the stdlib module).\n",
    "    \"\"\"\n",
    "    next_state = sys.forward_simulate(state, action)\n",
    "    return next_state\n",
    "\n",
    "def reward(state, action):\n",
    "    \"\"\"Reward of taking `action` in `state`, via `sys.reward_state_action`.\n",
    "\n",
    "    That method returns a pair; only its first element (the rewards) is kept.\n",
    "    \"\"\"\n",
    "    reward_values, _unused = sys.reward_state_action(state, action)\n",
    "    return reward_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `ttpi` is created in a later cell and passed in as the first argument.\n",
    "def callback(ttpi, file_name='fig', callback_count=0, n_steps=1000):\n",
    "    \"\"\"Roll out `ttpi.policy` from fixed start positions, print the reward and plot trajectories.\n",
    "\n",
    "    Args:\n",
    "        ttpi: the TTPI solver whose `policy` and `reward_normalized` are evaluated.\n",
    "        file_name: presumably passed by `ttpi.train`; unused here.\n",
    "        callback_count: presumably passed by `ttpi.train`; unused here.\n",
    "        n_steps: number of closed-loop simulation steps per rollout (must be >= 1).\n",
    "\n",
    "    Returns:\n",
    "        (mean reward over start states at the last step, mean cumulative reward), on CPU.\n",
    "    \"\"\"\n",
    "    if n_steps < 1:\n",
    "        raise ValueError(f\"n_steps must be >= 1, got {n_steps}\")\n",
    "    # Imported up front so a missing helper fails before the (costly) rollout.\n",
    "    from plot_utils import plot_point_mass\n",
    "    print(\"Testing....\")\n",
    "    # Fixed 2D start positions.\n",
    "    state = torch.tensor([[0.,0.],\n",
    "                          [-0.3,-0.],[-0.3,0.3],[-0.3,-0.3],\n",
    "                          [0.3,-0.],[0.3,0.3],[0.3,-0.3],\n",
    "                          [0,0.3],\n",
    "                          [-0.75,-0.75],[-0.75,0.75],[0.75,-0.75],\n",
    "                          [-0.6,-0.],[-0.6,0.2],[-0.5,-0.6],\n",
    "                          [0.6,-0.2],[0.6,0.5],[0.6,-0.7],\n",
    "                          [0.1,0.6],\n",
    "                          [-0.8,-0.],[-0.8,0.7],[-0.8,-0.5],\n",
    "                          [0.8,-0.],[0.,-0.9],[0.8,0.4],[0.8,-0.6],[0.9,0],\n",
    "                          [0.1,0.8], [-0.1,-0.8], [0.1,-0.8], [0.2,-0.8]]).to(device)\n",
    "    n_starts = state.shape[0]\n",
    "    if dim > 2:\n",
    "        # Extra position coordinates start at zero.\n",
    "        state = torch.cat((state, torch.zeros(n_starts, dim - 2, device=device)), dim=-1)\n",
    "    if order == 2:\n",
    "        # Append zero initial velocities.\n",
    "        state = torch.cat((state, torch.zeros_like(state)), dim=-1)\n",
    "\n",
    "    # Collect the first two position coordinates at each step and stack once at the end,\n",
    "    # instead of re-concatenating the whole trajectory every step (quadratic copying).\n",
    "    positions = [state[:, :2].clone()]\n",
    "    cum_reward = torch.zeros(n_starts, device=device)\n",
    "    for _ in range(n_steps):\n",
    "        action = ttpi.policy(state)\n",
    "        r = ttpi.reward_normalized(state, action)\n",
    "        cum_reward += r\n",
    "        state = forward_model(state, action)\n",
    "        positions.append(state[:, :2].clone())\n",
    "    traj = torch.stack(positions, dim=1)  # (n_starts, n_steps + 1, 2)\n",
    "\n",
    "    print(\"Cumulative reward: \", cum_reward.mean())\n",
    "    plt = plot_point_mass(traj.to('cpu'),\n",
    "        xmax=L,\n",
    "        x_target=torch.tensor([0,0]).to('cpu'),\n",
    "        x_obst=[x.to('cpu') for x in x_obst], r_obst=r_obst,\n",
    "        save_as=None,\n",
    "        figsize=5)\n",
    "    plt.show()\n",
    "\n",
    "    return r.mean().to(\"cpu\"), cum_reward.mean().to(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional manual evaluation (requires the TTPI cell below to have been run first):\n",
    "# callback(ttpi, file_name='fig')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the TTPI solver. Most settings come in `_v` / `_a` pairs\n",
    "# (presumably value-function vs. advantage/action approximation -- confirm in TTPI).\n",
    "ttpi = TTPI(\n",
    "    domain_state=domain_state,\n",
    "    domain_action=domain_action,\n",
    "    reward=reward,\n",
    "    forward_model=forward_model,\n",
    "    gamma=0.99,\n",
    "    rmax_v=100, rmax_a=100,\n",
    "    nswp_v=10, nswp_a=10,\n",
    "    kickrank_v=5, kickrank_a=10,\n",
    "    max_batch_v=10**4, max_batch_a=10**5,\n",
    "    eps_cross_v=1e-3, eps_cross_a=1e-3,\n",
    "    eps_round_v=1e-3, eps_round_a=1e-3,\n",
    "    n_samples=50,\n",
    "    normalize_reward=False,\n",
    "    verbose=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set resume=True to resume a previous training run.\n",
    "resume = False\n",
    "ttpi.train(\n",
    "    n_iter_max=200,\n",
    "    resume=resume,\n",
    "    callback=callback,\n",
    "    callback_freq=20,  # presumably: run `callback` every 20 iterations\n",
    "    verbose=False,\n",
    "    file_name='point_mass',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final evaluation of the trained policy on the fixed start states.\n",
    "callback(ttpi, file_name='point_mass_final')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "cf96f6c213ba3f9333b362e3bb271376c1f8feeec3b85b92580d68346ee16de3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
