{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import custom_gym\n",
    "import matplotlib\n",
    "# matplotlib.use('Agg')\n",
    "%matplotlib inline\n",
    "\n",
    "from dynamic_systems_rl import PointMass\n",
    "\n",
    "from stable_baselines3.common.env_checker import check_env\n",
    "from stable_baselines3 import PPO, SAC, DDPG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "# NOTE(review): `device` is not passed to `callback` in the training loop, which defaults to 'cpu'.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order of the point-mass dynamics: each state has dim*order entries\n",
    "# (with dim=2: order 1 -> 2 entries, order 2 -> 4 entries).\n",
    "order = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = 0.01  # time step shared by the agent, the gym env and the evaluation rollouts\n",
    "\n",
    "# Obstacles (position tensors and matching radii); none are used in this run.\n",
    "# Example: x_obst = [torch.tensor([0., -0.4])]; r_obst = [0.2]\n",
    "x_obst = []\n",
    "r_obst = []\n",
    "\n",
    "agent = PointMass(\n",
    "    order=order,\n",
    "    dt=dt,\n",
    "    dim=2,\n",
    "    x_obst=x_obst,\n",
    "    r_obst=r_obst,\n",
    "    w_obst=0.,\n",
    "    w_action=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-component bounds handed to CreateGymEnv (2 entries for order 1, 4 otherwise);\n",
    "# presumably the range used to sample initial states -- TODO confirm in custom_gym.\n",
    "reset_bound = np.array([1, 1]) if order == 1 else np.array([1, 1, 0.25, 0.25])\n",
    "env = custom_gym.CreateGymEnv(agent, reset_bound=reset_bound, dt=dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the custom environment against the Gym API expected by Stable-Baselines3;\n",
    "# problems are reported as warnings (or raised as errors).\n",
    "check_env(env=env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the SAC learner: 2x32 MLP networks with state-dependent exploration (gSDE).\n",
    "# This cell only instantiates the model; training happens in the loop further down.\n",
    "policy_kwargs = {\"net_arch\": [32, 32]}\n",
    "\n",
    "model = SAC(\n",
    "    policy=\"MlpPolicy\",\n",
    "    env=env,\n",
    "    learning_rate=1e-3,\n",
    "    gamma=0.99,\n",
    "    use_sde=True,\n",
    "    sde_sample_freq=4,\n",
    "    policy_kwargs=policy_kwargs,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy(state):\n",
    "    \"\"\"Return the deterministic SAC action (numpy array) for a batch of states.\"\"\"\n",
    "    action, _ = model.predict(state, deterministic=True)\n",
    "    return action\n",
    "\n",
    "\n",
    "def callback(agent, policy, dt, order=order, dim=2, device='cpu', plt=False,\n",
    "             n_samples=100, n_steps=5000):\n",
    "    \"\"\"Roll out `policy` from random initial states and report evaluation metrics.\n",
    "\n",
    "    Args:\n",
    "        agent: dynamics model exposing `reward_state_action(state, action)` and\n",
    "            `forward_simulate(state, action, dt)`.\n",
    "        policy: callable mapping a batch of states to a numpy array of actions.\n",
    "        dt: time step passed to `agent.forward_simulate`.\n",
    "        order, dim: each state has `dim * order` entries; the first `dim` are the position.\n",
    "        device: device the cumulative reward is accumulated on.\n",
    "        plt: if truthy, plot the trajectories with `plot_utils.plot_point_mass`.\n",
    "        n_samples: number of parallel rollouts; initial states are uniform in [-1, 1).\n",
    "        n_steps: number of simulation steps per rollout.\n",
    "\n",
    "    Returns:\n",
    "        (success_rate, mu, mean_cum_reward). A rollout succeeds when its final\n",
    "        position is within 0.1 of the origin; mu is the squared mean ratio of\n",
    "        straight-line distance to path length over the successful rollouts.\n",
    "    \"\"\"\n",
    "    print(\"Testing....\")\n",
    "    state = 2*(-0.5+torch.rand((n_samples, dim*order)))\n",
    "    cum_reward = torch.zeros(state.shape[0]).to(device)\n",
    "    # Collect positions in a list and stack once at the end: concatenating inside\n",
    "    # the loop re-copies the whole trajectory every step (quadratic in n_steps).\n",
    "    positions = [state[:, :dim].clone()]\n",
    "    for _ in range(n_steps):\n",
    "        action = torch.from_numpy(policy(state))\n",
    "        r = agent.reward_state_action(state, action)\n",
    "        cum_reward += r\n",
    "        state = agent.forward_simulate(state, action, dt)\n",
    "        # clone so a later in-place update of `state` cannot rewrite stored history\n",
    "        positions.append(state[:, :dim].clone())\n",
    "    traj = torch.stack(positions, dim=1)  # (n_samples, n_steps + 1, dim)\n",
    "\n",
    "    # Path efficiency: straight-line start->end distance over travelled path length.\n",
    "    dist_straight = torch.linalg.norm(traj[:, 0, :] - traj[:, -1, :], dim=-1).view(-1, 1)\n",
    "    d_traj = torch.linalg.norm(traj[:, 1:, :] - traj[:, :-1, :], dim=-1).sum(dim=-1).view(-1, 1)\n",
    "    mu_i = dist_straight / (1e-6 + d_traj)\n",
    "\n",
    "    dist_final = torch.linalg.norm(traj[:, -1, :], dim=-1).view(-1)\n",
    "    is_success = dist_final < 0.1\n",
    "    success_rate = torch.sum(is_success) / len(is_success)\n",
    "    # mean of mu_i over successful rollouts only (eps avoids 0/0 when none succeed)\n",
    "    mu = ((mu_i.view(-1) * is_success).sum() / (1e-9 + torch.sum(is_success)))**2\n",
    "    print(\"distance metric: \", mu)\n",
    "    print(\"sucess_rate: \", success_rate)\n",
    "    if plt:\n",
    "        # `plt` is the on/off flag here, so the plotting handle gets its own name.\n",
    "        from plot_utils import plot_point_mass\n",
    "        plot_handle = plot_point_mass(traj.to('cpu'), x_target=torch.tensor([0, 0]).to('cpu'),\n",
    "                                      x_obst=[x.to('cpu') for x in x_obst], r_obst=r_obst,\n",
    "                                      figsize=5)\n",
    "        plot_handle.show()\n",
    "    return success_rate, mu, cum_reward.mean().to(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "N_ITERS = 10               # training rounds\n",
    "STEPS_PER_ITER = int(1e5)  # environment steps per training round\n",
    "N = 5                      # evaluation repeats per round; the last one is plotted\n",
    "\n",
    "T = 0\n",
    "for i in range(N_ITERS):\n",
    "    t1 = time.perf_counter()\n",
    "    # reset_num_timesteps=False keeps counting from where the previous chunk stopped,\n",
    "    # so the rounds form one continuous run instead of restarting the timestep\n",
    "    # counter (and SB3's learning_starts random warmup) at 0 on every call.\n",
    "    model.learn(total_timesteps=STEPS_PER_ITER, reset_num_timesteps=False)\n",
    "    t2 = time.perf_counter()\n",
    "    T += t2 - t1\n",
    "    print(\"Time Taken Per Iteration: \", t2-t1)\n",
    "    print(\"Total time: \", T)\n",
    "    model.save('pm_acc_sac'+str(i))\n",
    "\n",
    "    S = torch.empty(N); Mu = torch.empty(N); CumR = torch.empty(N)\n",
    "    for j in range(N):  # bound tied to N so it always matches the size of S/Mu/CumR\n",
    "        s, mu, r_cum = callback(agent=agent, policy=policy,\n",
    "                                order=order, dim=2, dt=dt, plt=(j == N - 1))\n",
    "        S[j] = s; Mu[j] = mu; CumR[j] = r_cum\n",
    "    print(\"mu, mean:{}, std:{}\".format(torch.mean(Mu), torch.std(Mu)))\n",
    "    print(\"s, mean:{}, std:{}\".format(torch.mean(S), torch.std(S)))\n",
    "    print(\"cum_r, mean:{}, std:{}\".format(torch.mean(CumR), torch.std(CumR)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
