{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TTPI on the `HardMove` system\n",
    "\n",
    "Builds a `HardMove` dynamic system, discretizes its state domain and its per-actuator action domain (acceleration values plus a binary switch), trains a `TTPI` policy on it, and periodically rolls the current policy out from random initial states to plot the resulting trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "# Make the project modules one directory up (ttpi, dynamic_systems) importable.\n",
    "sys.path.append('../')\n",
    "from ttpi import TTPI\n",
    "from dynamic_systems import HardMove\n",
    "\n",
    "# Floating-point tensors created below default to float64.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "# Seed torch's RNG so the random initial test states (and any other sampling) are reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is available, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation time step; `callback` also uses it to turn its rollout horizon into a step count.\n",
    "dt = 0.01\n",
    "\n",
    "# w_goal / w_action are passed straight to HardMove (presumably weights of the goal and action terms of its reward).\n",
    "dyn_system = HardMove(\n",
    "    dt=dt,\n",
    "    w_goal=1e3,\n",
    "    w_action=1e4,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_model(state, action):\n",
    "    \"\"\"Return the state reached from `state` under `action`, via `dyn_system.forward_simulate`.\"\"\"\n",
    "    return dyn_system.forward_simulate(state, action)\n",
    "\n",
    "\n",
    "def reward(state, action):\n",
    "    \"\"\"Return the reward of taking `action` in `state`, via `dyn_system.reward_state_action`.\"\"\"\n",
    "    return dyn_system.reward_state_action(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Workspace and actuation limits ---\n",
    "L = 1.\n",
    "position_max = L\n",
    "position_min = -1*L\n",
    "\n",
    "n_state = 50       # grid points per state dimension\n",
    "n_actuator = 12    # number of actuators; each contributes an (acceleration, switch) pair to the action\n",
    "n_action = 100     # grid points for each actuator's acceleration\n",
    "\n",
    "velocity_max = 0.25*position_max\n",
    "# NOTE(review): velocities are bounded to [0, velocity_max] while accelerations may be negative;\n",
    "# confirm that HardMove keeps velocities non-negative.\n",
    "velocity_min = 0*velocity_max\n",
    "acc_max = 1*velocity_max\n",
    "acc_min = -1*acc_max\n",
    "\n",
    "# Per-actuator action domains: a discretized acceleration and a binary switch {0, 1}.\n",
    "domain_acc0 = torch.linspace(acc_min, acc_max, n_action).to(device)\n",
    "domain_switch0 = torch.arange(2).to(device)\n",
    "\n",
    "# State bounds, dim=4: [x, y, x_dot, y_dot].\n",
    "state_min = torch.tensor([position_min, position_min, velocity_min, velocity_min]).to(device)\n",
    "state_max = torch.tensor([position_max, position_max, velocity_max, velocity_max]).to(device)\n",
    "\n",
    "# One uniform grid of n_state points per state dimension.\n",
    "domain_state = [torch.linspace(lo, hi, n_state).to(device) for lo, hi in zip(state_min, state_max)]\n",
    "# Action domain: the (acceleration, switch) pair repeated once per actuator.\n",
    "domain_action = [domain_acc0, domain_switch0]*n_actuator\n",
    "\n",
    "action_max = acc_max\n",
    "action_min = -acc_max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the state is 4-dimensional ([x, y, x_dot, y_dot]) and has one grid per bound.\n",
    "assert len(domain_state) == len(state_min) == len(state_max) == 4\n",
    "len(domain_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This task has no obstacles; `callback` forwards these empty lists to the plotting helper.\n",
    "x_obst, r_obst = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_test = 20\n",
    "dim_state = len(domain_state)\n",
    "\n",
    "# Sample n_test initial states uniformly from the inner half of the state box (a fraction in\n",
    "# [0.25, 0.75] of each dimension's range). Rescaling torch.rand keeps the samples uniform;\n",
    "# clipping it to [0.25, 0.75] would instead pile about half of them onto the two bounds.\n",
    "frac = 0.25 + 0.5*torch.rand(n_test, dim_state, device=device)\n",
    "init_state = state_min + frac*(state_max - state_min)\n",
    "# Every test rollout starts at rest (x_dot = y_dot = 0).\n",
    "init_state[:, 2:4] = 0.\n",
    "# Independent copy, so later edits to `state` never alias `init_state` (on CPU or GPU alike).\n",
    "state = init_state.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def callback(ttpi, state=state, file_name='fig', callback_count=0):\n",
    "    \"\"\"Roll out the current TTPI policy from `state`, print the mean return and plot the paths.\n",
    "\n",
    "    Handed to `ttpi.train`. `file_name` and `callback_count` are accepted to keep the\n",
    "    original signature but are not used here.\n",
    "\n",
    "    Returns:\n",
    "        (mean reward of the last step, mean cumulative reward), both moved to the CPU.\n",
    "    \"\"\"\n",
    "    from plot_utils import plot_point_mass\n",
    "\n",
    "    horizon = 10  # rollout length, in the same time units as `dt`\n",
    "    # round() rather than int() so a float quotient such as 999.9999999 still gives 1000 steps.\n",
    "    n_steps = int(round(horizon/dt))\n",
    "\n",
    "    cum_reward = torch.zeros(state.shape[0], device=state.device)\n",
    "    positions = [state[:, :2].clone()]  # (x, y) of every rollout, one entry per time step\n",
    "    for _ in range(n_steps):\n",
    "        action = ttpi.policy(state)\n",
    "        r = dyn_system.reward_state_action(state, action)\n",
    "        cum_reward += r\n",
    "        state = dyn_system.forward_simulate(state, action)\n",
    "        positions.append(state[:, :2].clone())\n",
    "    # Stack once at the end: re-concatenating inside the loop copies the whole trajectory every step.\n",
    "    traj = torch.stack(positions, dim=1)  # (n_rollouts, n_steps + 1, 2)\n",
    "\n",
    "    print(\"Cumulative reward: \", cum_reward.mean())\n",
    "    # plot_point_mass returns an object exposing grid()/show() (presumably pyplot itself).\n",
    "    plot_handle = plot_point_mass(traj.to('cpu'), x_target=torch.tensor([0,0]).to('cpu'),\n",
    "                                  x_obst=[x.to('cpu') for x in x_obst], r_obst=r_obst, figsize=5)\n",
    "    plot_handle.grid()\n",
    "    plot_handle.show()\n",
    "    return r.mean().to(\"cpu\"), cum_reward.mean().to(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the *_v / *_a arguments appear to configure the value-function and the action\n",
    "# tensor-train approximations respectively (max rank, sweeps, kick rank, max batch size,\n",
    "# cross-approximation and rounding tolerances); confirm against ttpi.py.\n",
    "ttpi = TTPI(\n",
    "    domain_state=domain_state,\n",
    "    domain_action=domain_action,\n",
    "    reward=reward,\n",
    "    normalize_reward=True,\n",
    "    forward_model=forward_model,\n",
    "    gamma=0.99,  # presumably the discount factor\n",
    "    rmax_v=100, rmax_a=100,\n",
    "    nswp_v=5, nswp_a=10,\n",
    "    kickrank_v=10, kickrank_a=10,\n",
    "    max_batch_v=10**4, max_batch_a=10**5,\n",
    "    eps_cross_v=1e-3, eps_cross_a=1e-3,\n",
    "    eps_round_v=1e-3, eps_round_a=1e-3,\n",
    "    n_samples=50,\n",
    "    verbose=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably resume=True continues a previous run saved under `file_name`, while False starts\n",
    "# fresh; `callback` is handed to train together with callback_freq=10 (confirm in TTPI.train).\n",
    "resume = False\n",
    "ttpi.train(resume=resume, n_iter_v=1,\n",
    "           callback=callback, callback_freq=10,\n",
    "           verbose=False, file_name='HardMove')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
