{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from os.path import abspath, dirname\n",
    "sys.path.insert(0, \"..\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "import pprint\n",
    "from dotmap import DotMap\n",
    "from MBExperiment import MBExperiment\n",
    "from MPC import MPC\n",
    "from config import create_config\n",
    "import env # We run this so that the env is registered\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import tqdm\n",
    "from config.cartpole import CartpoleConfigModule\n",
    "import gym\n",
    "import pandas as pd\n",
    "import os\n",
    "from time import localtime, strftime\n",
    "from dotmap import DotMap\n",
    "from scipy.io import savemat\n",
    "from tqdm import trange\n",
    "from Agent import Agent\n",
    "from DotmapUtils import get_required_argument\n",
    "from utils import to_json, read_json, to_pickle, read_pickle\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import glob as glob\n",
    "import sys\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import copy\n",
    "import pickle\n",
    "# import mpld3\n",
    "# mpld3.enable_notebook()\n",
    "import utils\n",
    "import time\n",
    "TORCH_DEVICE = utils.TORCH_DEVICE\n",
    "VARIATION_NOISE = 1e-12\n",
    "\n",
    "def set_global_seeds(seed):\n",
    "    \"\"\"Seed every random number generator used by the notebook.\n",
    "\n",
    "    Seeds torch (CPU and, when available, all CUDA devices), numpy's global RNG,\n",
    "    Python's `random` module and TensorFlow.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed applied to every generator.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    # `tf.set_random_seed` only exists in TensorFlow 1.x; TF 2.x exposes `tf.random.set_seed`.\n",
    "    if hasattr(tf, 'set_random_seed'):\n",
    "        tf.set_random_seed(seed)\n",
    "    else:\n",
    "        tf.random.set_seed(seed)\n",
    "\n",
    "def print_gpu_memory(device=0):\n",
    "    \"\"\"Print total, currently allocated and peak allocated CUDA memory (GB) of `device`.\n",
    "\n",
    "    On machines without CUDA it prints a short notice instead of raising, so this\n",
    "    setup cell also runs on CPU-only hosts.\n",
    "    \"\"\"\n",
    "    if not torch.cuda.is_available():\n",
    "        print('CUDA not available: no GPU memory to report.')\n",
    "        return\n",
    "    gb = 1e9\n",
    "    # All three figures are queried for the same device.\n",
    "    print('Total memory: %.3f GB. Memory allocated: %.5f GB. Max allocated: %.5f GB ' % (\n",
    "        torch.cuda.get_device_properties(device).total_memory / gb,\n",
    "        torch.cuda.memory_allocated(device) / gb,\n",
    "        torch.cuda.max_memory_allocated(device) / gb))\n",
    "print_gpu_memory()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse(predictions, targets):\n",
    "    \"\"\"Root-mean-square error between `predictions` and `targets`.\n",
    "\n",
    "    The backend is chosen from the type of `predictions`: numpy arrays give a numpy\n",
    "    scalar, anything else is treated as a torch tensor and gives a 0-d tensor.\n",
    "    \"\"\"\n",
    "    # Dispatch on the argument itself so the result never depends on global state.\n",
    "    if isinstance(predictions, np.ndarray):\n",
    "        return np.sqrt(((predictions - targets) ** 2).mean())\n",
    "    return torch.sqrt(((predictions - targets) ** 2).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Active experiment; other options: cartpole, reacher, halfcheetah, pointbot.\n",
    "env = 'pusher'\n",
    "\n",
    "exp_name = '09_test_{}'.format(env)\n",
    "# Same argv that `python 8_online.py ...` would receive, so it can be fed to argparse here.\n",
    "cmd_line_args = [\n",
    "    'python', '8_online.py',\n",
    "    '-env', env,\n",
    "    '-logdir', exp_name,\n",
    "    '--METHOD', 'BASELINE',\n",
    "    '--S_FUT_KL_CST', '2',\n",
    "    '--S_FUT_LASTEPS', '10',\n",
    "    '--MAX_BUFFER_LENGTH', '100000',\n",
    "    '--COLD_START_STEPS', '0',\n",
    "    '--MANEUVER', 'sector_1',\n",
    "    '--SAVE',\n",
    "]\n",
    "# Drop falsy entries (None would crash argparse) and the leading interpreter/script names.\n",
    "cmd_line_args = [arg for arg in cmd_line_args if arg][2:]\n",
    "cmd_line_args\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# # Alternative configuration: reacher\n",
    "# env = 'reacher' #cartpole, pusher, reacher, halfcheetah\n",
    "# #env = 'pointbot'\n",
    "# #env = 'halfcheetah'\n",
    "\n",
    "# exp_name = '09_test_%s' % env\n",
    "# cmd_line_args = ['python', '8_online.py',\n",
    "#     '-env', env,\n",
    "#     '-logdir', exp_name,\n",
    "#     '--METHOD', 'BASELINE',\n",
    "#     '--S_FUT_KL_CST', '8',\n",
    "#     '--S_FUT_LASTEPS', '20',\n",
    "#     '--MAX_BUFFER_LENGTH', '100000',\n",
    "#     '--COLD_START_STEPS', '0',\n",
    "#     '--MANEUVER', 'sector_1',\n",
    "#     '--SAVE'\n",
    "#     ]\n",
    "# cmd_line_args = list(filter(None, cmd_line_args)) # remove None, otherwise argparse crash\n",
    "# cmd_line_args = cmd_line_args[2:]\n",
    "# cmd_line_args\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# # Alternative configuration: cartpole\n",
    "# env = 'cartpole' #cartpole, pusher, reacher, halfcheetah\n",
    "# #env = 'pointbot'\n",
    "# #env = 'halfcheetah'\n",
    "\n",
    "# exp_name = '09_test_%s' % env\n",
    "# cmd_line_args = ['python', '8_online.py',\n",
    "#     '-env', env,\n",
    "#     '-logdir', exp_name,\n",
    "#     '--METHOD', 'BASELINE',\n",
    "#     '--S_FUT_KL_CST', '32',\n",
    "#     '--S_FUT_LASTEPS', '20',\n",
    "#     '--MAX_BUFFER_LENGTH', '100000',\n",
    "#     '--COLD_START_STEPS', '0',\n",
    "#     '--MANEUVER', 'sector_1',\n",
    "#     '--SAVE'\n",
    "#     ]\n",
    "# cmd_line_args = list(filter(None, cmd_line_args)) # remove None, otherwise argparse crash\n",
    "# cmd_line_args = cmd_line_args[2:]\n",
    "# cmd_line_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('-env', type=str, required=True,\n",
    "                    help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]')\n",
    "parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[],\n",
    "                    help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments')\n",
    "parser.add_argument('-o', '--override', action='append', nargs=2, default=[],\n",
    "                    help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides')\n",
    "parser.add_argument('-logdir', type=str, default='log',\n",
    "                    help='Directory to which results will be logged (default: ./log)')\n",
    "\n",
    "# enables\n",
    "# `choices` makes argparse reject typos; an unknown value would otherwise silently fall\n",
    "# through to the non-UARF / non-BASELINE code paths.\n",
    "parser.add_argument('--METHOD', type=str, required = True, choices=['UARF', 'BICHO', 'BASELINE'],\n",
    "                    help=\"Experiment method: UARF, BICHO, or BASELINE\")\n",
    "parser.add_argument('--MANEUVER', type=str, default = 'straight')\n",
    "\n",
    "# misc\n",
    "parser.add_argument('--SAVE', action='store_true')\n",
    "parser.add_argument('--LOAD', type=str)\n",
    "\n",
    "\n",
    "parser.add_argument('--COLD_START_STEPS', default=0, type=float)\n",
    "parser.add_argument('--NEW_DATA_TRAIN_THRESHOLD', default=0.01, type=float)\n",
    "parser.add_argument('--MAX_BUFFER_LENGTH', default=None, type=float)\n",
    "\n",
    "## FUT\n",
    "parser.add_argument('--S_FUT_LASTEPS', default=20, type=int,\n",
    "                    help='Number of predicted future steps compared when deciding to re-plan')\n",
    "\n",
    "# KL\n",
    "parser.add_argument('--S_FUT_KL_CST', default=32, type=float,\n",
    "                    help='KL-divergence threshold above which the controller re-plans')\n",
    "# Any value other than obs/cost would silently disable KL-triggered re-planning (kl_loss = 0).\n",
    "parser.add_argument('--S_FUT_E_INPUT', default=\"cost\", type=str, choices=['obs', 'cost'], help='[obs, cost]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Parse the emulated command line built in the configuration cell;\n",
    "# when running as a script use `parser.parse_args()` (reads sys.argv) instead.\n",
    "args = parser.parse_args(args=cmd_line_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "output_path = os.path.join(args.logdir, \"\")\n",
    "os.makedirs(output_path, exist_ok=True)\n",
    "output_path_by_episode = os.path.join(args.logdir, \"episodes\")\n",
    "os.makedirs(output_path_by_episode, exist_ok=True)\n",
    "\n",
    "\n",
    "# Enables: hyper-parameters copied from the parsed arguments.\n",
    "\n",
    "S_FUT_LASTEPS = args.S_FUT_LASTEPS\n",
    "\n",
    "METHOD = args.METHOD\n",
    "\n",
    "# KL FUT\n",
    "S_FUT_KL_CST = args.S_FUT_KL_CST\n",
    "S_FUT_E_INPUT = args.S_FUT_E_INPUT\n",
    "COLD_START_STEPS = args.COLD_START_STEPS\n",
    "NEW_DATA_TRAIN_THRESHOLD = args.NEW_DATA_TRAIN_THRESHOLD\n",
    "# sample_P adds a KL-triggered re-plan to the buffer only if fewer than this many steps\n",
    "# passed since the previous re-plan; the training loop recomputes it after every iteration.\n",
    "MAX_PREDICTION_DISTANCE = 10e6\n",
    "MAX_BUFFER_LENGTH = args.MAX_BUFFER_LENGTH\n",
    "env = args.env\n",
    "\n",
    "ctrl_type = 'MPC'\n",
    "logdir = os.path.join(args.logdir, \"buffers\")\n",
    "os.makedirs(logdir, exist_ok=True)\n",
    "overrides = args.override\n",
    "\n",
    "# Controller arguments passed with -ca/--ctrl_arg KEY VALUE.\n",
    "ctrl_args = DotMap(**{key: val for (key, val) in args.ctrl_arg})\n",
    "cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)\n",
    "cfg.pprint()\n",
    "\n",
    "assert ctrl_type == 'MPC'\n",
    "\n",
    "\n",
    "# overwrites\n",
    "cfg.ctrl_cfg.per = 10\n",
    "cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)\n",
    "exp = MBExperiment(cfg.exp_cfg)\n",
    "\n",
    "# The experiment directory must exist before config.txt is written into it.\n",
    "os.makedirs(exp.logdir, exist_ok=True)\n",
    "with open(os.path.join(exp.logdir, \"config.txt\"), \"w\") as f:\n",
    "    f.write(pprint.pformat(cfg.toDict()))\n",
    "\n",
    "maneuver_index = -1\n",
    "# The remaining cells refer to the experiment object as `self`.\n",
    "self = exp\n",
    "if env == \"pointbot\":\n",
    "    # Task horizon (number of steps) of each pointbot maneuver; the key order defines maneuver_index.\n",
    "    maneuver_lengths = {\n",
    "        \"c1\": 100,\n",
    "        \"c1_i\": 100,\n",
    "        \"chicane\": 50,\n",
    "        \"chicane_i\": 50,\n",
    "        \"c14\": 100,\n",
    "        \"c14_i\": 100,\n",
    "        \"straight\": 100,\n",
    "        \"sector_1\": 290,\n",
    "        \"sector_1_i\": 290,\n",
    "        \"full_track_barcelona\": 1200,\n",
    "    }\n",
    "    self.task_hor = maneuver_lengths[args.MANEUVER]\n",
    "    self.env.set_maneuver(args.MANEUVER, TORCH_DEVICE)\n",
    "    maneuvers = list(maneuver_lengths.keys())\n",
    "    maneuver_index = maneuvers.index(args.MANEUVER)\n",
    "\n",
    "\n",
    "# # Exp\n",
    "print(\"*\"  * 80)\n",
    "print(\"### Starting\")\n",
    "print(\"output_path\", output_path)\n",
    "print(\"env\", env)\n",
    "print(\"task hor:\", self.task_hor)\n",
    "print(\"self.ntrain_iters:\", self.ntrain_iters)\n",
    "\n",
    "print(\"Enables:\")\n",
    "\n",
    "print(\"FUT:\")\n",
    "print(\"S_FUT_LASTEPS\", S_FUT_LASTEPS)\n",
    "print('S_FUT_KL_CST', S_FUT_KL_CST)\n",
    "print(\"*\"  * 80)\n",
    "\n",
    "# New version (simplified) of the MPC action selection, called as act_dyn(policy, obs, t, ...).\n",
    "def act_dyn(self, obs, t, recalculate=True, get_pred_cost=False):\n",
    "    \"\"\"Return the action the MPC controller `self` takes at step `t` for observation `obs`.\n",
    "\n",
    "    Arguments:\n",
    "        obs: The current observation.\n",
    "        t: The current timestep (not used by this implementation).\n",
    "        recalculate: If True, re-run the optimizer to obtain a fresh action sequence;\n",
    "            otherwise continue with the remainder of the previously planned sequence.\n",
    "        get_pred_cost: Kept for interface compatibility; it is ignored.\n",
    "\n",
    "    Returns: A single action. Before the model has been trained this is a uniform\n",
    "        random sample shaped like `self.ac_lb`; afterwards it is the first action of\n",
    "        the stored plan, reshaped to (1, self.dU).\n",
    "    \"\"\"\n",
    "    if not self.has_been_trained:\n",
    "        return np.random.uniform(self.ac_lb, self.ac_ub, self.ac_lb.shape)\n",
    "\n",
    "    if recalculate:\n",
    "        # Stored on the policy for use by the compiled cost function during optimization.\n",
    "        self.sy_cur_obs = obs\n",
    "        started = time.time()\n",
    "        # Warm-start from the (shifted) previous plan, always with the initial variance.\n",
    "        plan = self.optimizer.obtain_solution(self.prev_sol, self.init_var)\n",
    "        print(f\"recalculation time: {time.time() - started}\")\n",
    "        self.prev_sol_full = np.copy(plan)\n",
    "        self.prev_sol = np.copy(plan)\n",
    "        self.prev_sol_obs = np.copy(obs)\n",
    "\n",
    "    self.prev_sol_with_action = self.prev_sol.copy()\n",
    "\n",
    "    dU = self.dU\n",
    "    action = self.prev_sol[:dU].reshape(-1, dU)\n",
    "    # Consume the first action: shift the plan left by one action and zero-pad the tail.\n",
    "    self.prev_sol = np.concatenate([np.copy(self.prev_sol)[dU:], np.zeros(dU)])\n",
    "    return action\n",
    "\n",
    "def project_trajectory(self, obs, actions, start_index):\n",
    "    \"\"\"Roll the learned model of policy `self` forward from `obs` under a shifted plan.\n",
    "\n",
    "    Arguments:\n",
    "        obs: observation the rollout starts from (stored in `self.sy_cur_obs`).\n",
    "        actions: flat planned action sequence, e.g. `prev_sol_full`.\n",
    "        start_index: number of leading actions to drop; the sequence is shifted left by\n",
    "            `start_index` actions and zero-padded so its length is unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (obs_mean, obs_std, cost_mean, cost_std): mean and standard deviation over axis 1\n",
    "        of the predicted observations (`traj_cur_obs`) and next costs (`traj_next_cost`).\n",
    "        NOTE(review): before the model is trained this returns a single random action\n",
    "        instead, which callers that unpack four values cannot handle -- confirm intent.\n",
    "    \"\"\"\n",
    "    if not self.has_been_trained:\n",
    "        return np.random.uniform(self.ac_lb, self.ac_ub, self.ac_lb.shape)\n",
    "\n",
    "    self.sy_cur_obs = obs\n",
    "\n",
    "    offset = start_index * self.dU\n",
    "    shifted = np.concatenate([actions[offset:], np.zeros(offset)])\n",
    "\n",
    "    _, traj = self._compile_cost(shifted.reshape(1, -1), return_seq=True)\n",
    "    # Axis 1 presumably indexes the model particles; reducing it leaves per-step statistics.\n",
    "    pred_obs = np.array(traj['traj_cur_obs'])\n",
    "    pred_cost = np.array(traj['traj_next_cost'])\n",
    "    return (pred_obs.mean(axis=1), pred_obs.std(axis=1),\n",
    "            pred_cost.mean(axis=1), pred_cost.std(axis=1))\n",
    "\n",
    "# Build the dynamics model; maneuver_index is -1 for every env except pointbot.\n",
    "self.initialize_model(maneuver_index)\n",
    "\n",
    "# Target statistics for scaling; the first 500 samples are skipped once the buffer holds at least 500.\n",
    "first_kept = 0 if len(self.policy.train_in) < 500 else 500\n",
    "train_in = self.policy.train_in[first_kept:].copy()\n",
    "train_targs = self.policy.train_targs[first_kept:].copy()\n",
    "\n",
    "train_targs_mu = np.mean(train_targs, axis=0, keepdims=True)\n",
    "train_targs_sigma = np.std(train_targs, axis=0, keepdims=True)\n",
    "# Replace (near-)zero standard deviations by 1 so constant dimensions can be divided by safely.\n",
    "train_targs_sigma[train_targs_sigma < 1e-12] = 1.0\n",
    "\n",
    "def sample_P(self, horizon, verbose=False):\n",
    "    \"\"\"Run one episode of at most `horizon` steps and log per-step model diagnostics.\n",
    "\n",
    "    At every step the policy either re-plans or keeps executing its previous plan\n",
    "    (see `act_dyn`). After acting, the stored plan is projected with the learned model\n",
    "    twice -- from the observation of the last re-plan, and from `O[t]` with the first\n",
    "    `skip_t` planned actions dropped -- and the KL divergence between the two predicted\n",
    "    cost (or observation, per `S_FUT_E_INPUT`) distributions is kept in `kl_loss`.\n",
    "    On the next step the policy re-plans when `kl_loss > S_FUT_KL_CST` or when\n",
    "    `skip_t` reaches `plan_hor - 2`. With METHOD == 'BASELINE' every step re-plans\n",
    "    and is flagged for the replay buffer.\n",
    "\n",
    "    Reads the notebook globals METHOD, S_FUT_KL_CST, S_FUT_LASTEPS, S_FUT_E_INPUT,\n",
    "    MAX_PREDICTION_DISTANCE, VARIATION_NOISE, TORCH_DEVICE and env.\n",
    "\n",
    "    Arguments:\n",
    "        self: the MBExperiment whose `env` and `policy` are used.\n",
    "        horizon: maximum number of environment steps.\n",
    "        verbose: if True, print a one-line summary per step.\n",
    "\n",
    "    Returns: dict with `obs`/`obs_` (observations before/after each step), `ac`,\n",
    "        `reward_sum`, `rewards`, `porc_recalc` (fraction of steps that re-planned)\n",
    "        and the per-step 0/1 arrays `add_to_buffer` and `recalculated`.\n",
    "    \"\"\"\n",
    "    run_start = time.time()\n",
    "\n",
    "    policy = self.policy\n",
    "\n",
    "    times, rewards = [], []\n",
    "    errors = []\n",
    "    O, A, reward_sum, done = [self.env.reset()], [], 0, False\n",
    "\n",
    "    policy.reset()\n",
    "    episode_info = []\n",
    "    skip_step = 0 # unused; skip_t below counts the steps since the last re-plan\n",
    "    skip_t = 0\n",
    "\n",
    "    #for t in range(10):\n",
    "    for t in range(horizon):\n",
    "        # NOTE(review): renders on every step, which is slow and may need a display.\n",
    "        self.env.render()\n",
    "        recalculate = False\n",
    "        # Every step is kept unless METHOD is UARF and the buffer already holds cold_start_steps samples.\n",
    "        if METHOD != \"UARF\" or len(policy.train_in) < policy.cold_start_steps:\n",
    "            add_to_buffer = True\n",
    "        else:\n",
    "            add_to_buffer = False\n",
    "        if (t==0):\n",
    "            recalculate = True\n",
    "        else:\n",
    "            # kl_loss was computed at the end of the previous step\n",
    "            if kl_loss > S_FUT_KL_CST:\n",
    "                recalculate = True\n",
    "                if skip_t < MAX_PREDICTION_DISTANCE:\n",
    "                    add_to_buffer = True\n",
    "                print('# kl skip', S_FUT_KL_CST, kl_loss)\n",
    "\n",
    "        if (skip_t >= (self.policy.plan_hor - 2)):\n",
    "            \"\"\" If we used all the planned actions we need to recalculate \"\"\"\n",
    "            print('recalculate because of plan hor %d' % self.policy.plan_hor)\n",
    "            recalculate = True\n",
    "\n",
    "        if METHOD == 'BASELINE':\n",
    "            ''' disable skip steps by recalculating always '''\n",
    "            recalculate = True\n",
    "            add_to_buffer = True\n",
    "        a_t = act_dyn(policy, O[t], t, recalculate=recalculate)\n",
    "        A.append(a_t)\n",
    "\n",
    "        # skip_t: number of steps executed from the current plan since it was computed\n",
    "        if recalculate:\n",
    "            obs_when_recalc = np.copy( O[t] )\n",
    "            skip_t = 0\n",
    "        else:\n",
    "            skip_t += 1\n",
    "\n",
    "        obs, reward, done, info = self.env.step(A[t])\n",
    "        if env == 'reacher':\n",
    "            print('goal:', self.env.goal)\n",
    "            print(np.sum(np.square(self.env.get_EE_pos(self.env._get_obs()[None]) - self.env.goal)))\n",
    "        \n",
    "        O.append(obs.copy())\n",
    "        reward_sum += reward\n",
    "        rewards.append(reward)\n",
    "\n",
    "        # Prediction of the whole stored plan from the observation it was computed for\n",
    "        a = np.copy(policy.prev_sol_full)\n",
    "        obs_mean, obs_std, cost_mean, cost_std = project_trajectory(policy, obs_when_recalc, a, start_index=0)\n",
    "        cost_pred = cost_mean.copy()\n",
    "        # Prediction of the remaining plan (first skip_t actions dropped) from O[t]\n",
    "        a = np.copy(policy.prev_sol_full)\n",
    "        obs_mean_fut, obs_std_fut, cost_mean_fut, cost_std_fut = project_trajectory(policy, O[t+0], a, start_index=skip_t+0)\n",
    "\n",
    "        # Cost of the real next observation and the applied action, using the policy's cost functions\n",
    "        next_obs = torch.tensor( obs.reshape(1,-1), device=TORCH_DEVICE )\n",
    "        cur_acs = torch.tensor( a_t.reshape(1,-1), device=TORCH_DEVICE )\n",
    "        cost = policy.obs_cost_fn(next_obs) + policy.ac_cost_fn(cur_acs)\n",
    "        cost = cost.detach().cpu().numpy()\n",
    "\n",
    "        # Predicted observation at plan index skip_t+1, compared below with the real `obs`\n",
    "        pred_obs = obs_mean[skip_t+1].copy()\n",
    "        pred_obs_std = obs_std[skip_t+1].copy()\n",
    "\n",
    "        # calculate error\n",
    "        e =(pred_obs[:] - obs[:])  #/ train_targs_sigma[0][1:] # skip possition and normalize\n",
    "\n",
    "        error_obs_mean = np.mean(e)\n",
    "        error_obs_euc_d = np.sqrt( np.sum(np.square(e)))\n",
    "\n",
    "\n",
    "        # KL(prediction from last re-plan || prediction from O[t]) over at most S_FUT_LASTEPS steps\n",
    "        if S_FUT_E_INPUT == 'obs':\n",
    "            # NOTE(review): obs_mean starts at skip_t+1 while obs_mean_fut, already shifted by\n",
    "            # skip_t in project_trajectory, is sliced again from skip_t; confirm both windows\n",
    "            # refer to the same time steps. Unlike the cost branch, no VARIATION_NOISE is added.\n",
    "            obs_mean = obs_mean[skip_t+1:][0:S_FUT_LASTEPS]\n",
    "            obs_std =  obs_std[skip_t+1:][0:S_FUT_LASTEPS]\n",
    "\n",
    "            obs_mean_fut = obs_mean_fut[skip_t+0:][0:S_FUT_LASTEPS]\n",
    "            obs_std_fut = obs_std_fut[skip_t+0:][0:S_FUT_LASTEPS]\n",
    "\n",
    "            # crop when we are reaching the end of the trajectory\n",
    "            crop = min(obs_mean.shape[0], S_FUT_LASTEPS)\n",
    "            obs_mean_fut = obs_mean_fut[:crop]\n",
    "            obs_std_fut = obs_std_fut[:crop]\n",
    "\n",
    "            dist_1 = torch.distributions.Normal(torch.tensor(obs_mean),\n",
    "                                                   torch.tensor(obs_std))\n",
    "\n",
    "            dist_2 = torch.distributions.Normal(torch.tensor(obs_mean_fut),\n",
    "                                                torch.tensor(obs_std_fut) )\n",
    "\n",
    "            kl_loss = torch.distributions.kl_divergence(dist_1, dist_2)#.mean(axis=1)\n",
    "            kl_loss = kl_loss.mean().numpy()\n",
    "        elif  S_FUT_E_INPUT == 'cost':\n",
    "            # NOTE(review): both windows start at skip_t although cost_mean_fut is already\n",
    "            # shifted by skip_t in project_trajectory; confirm the intended alignment.\n",
    "            cost_mean = cost_mean[skip_t+0:][0:S_FUT_LASTEPS]\n",
    "            # VARIATION_NOISE keeps the Normal scales strictly positive when a std is 0\n",
    "            cost_std =  cost_std[skip_t+0:][0:S_FUT_LASTEPS] + VARIATION_NOISE\n",
    "\n",
    "            cost_mean_fut = cost_mean_fut[skip_t:][0:S_FUT_LASTEPS]\n",
    "            cost_std_fut = cost_std_fut[skip_t:][0:S_FUT_LASTEPS]\n",
    "\n",
    "            # crop when we are reaching the end of the trajectory\n",
    "            crop = min(cost_mean.shape[0], S_FUT_LASTEPS)\n",
    "            cost_mean_fut = cost_mean_fut[:crop]\n",
    "            cost_std_fut = cost_std_fut[:crop] + VARIATION_NOISE\n",
    "            dist_1 = torch.distributions.Normal(torch.tensor(cost_mean),\n",
    "                                                   torch.tensor(cost_std))\n",
    "            dist_2 = torch.distributions.Normal(torch.tensor(cost_mean_fut),\n",
    "                                                torch.tensor(cost_std_fut) )\n",
    "\n",
    "\n",
    "            kl_loss = torch.distributions.kl_divergence(dist_1, dist_2)#.mean(axis=1)\n",
    "            kl_loss = kl_loss.mean().numpy()\n",
    "\n",
    "            print('cost', kl_loss)\n",
    "            \n",
    "            print('cost_mean', cost_mean, 'std', cost_std)\n",
    "            print('cost_mean_fut', cost_mean_fut, 'cost_std_fut', cost_std_fut)\n",
    "\n",
    "            # A NaN divergence is replaced by 1e9, forcing a re-plan on the next step\n",
    "            # (for any S_FUT_KL_CST below 1e9).\n",
    "            if np.isnan(kl_loss):\n",
    "                print(cost_mean)\n",
    "                print(cost_std)\n",
    "                print(cost_mean_fut)\n",
    "                print(cost_std_fut)\n",
    "                kl_loss = 1e9\n",
    "        else:\n",
    "            # Unknown input type: KL-triggered re-planning is effectively disabled.\n",
    "            kl_loss = 0\n",
    "\n",
    "\n",
    "        #train_set_euc_dist_error_mean, train_set_euc_dist_error_std = e_disc.mean(), e_disc.std()\n",
    "        # Per-step diagnostics; observation components are added below as separate columns\n",
    "        r = {'error_obs_mean': error_obs_mean.copy(),\n",
    "             'pred_cost': cost_pred[skip_t],\n",
    "             'sim_cost':cost[0],\n",
    "             'error_obs_euc_d':error_obs_euc_d,\n",
    "             'recalculate':float(recalculate),\n",
    "             'add_to_buffer':float(add_to_buffer),\n",
    "             'run_wall_time':time.time() - run_start,\n",
    "             'kl_loss':kl_loss,\n",
    "             'skip_t':skip_t\n",
    "            }\n",
    "        r.update( {\"sim_obs_%.2d\" % f:obs[f] for f in range( obs.shape[0] )} )\n",
    "        r.update( {\"pred_obs_%.2d\" % f:pred_obs[f] for f in range( pred_obs.shape[0] )} )\n",
    "\n",
    "        if verbose:\n",
    "            print(\"r_t:%.2f\" % reward, \"reward_sum:%.2f\" % reward_sum, \"%.2d %.2d\" % (t, skip_t), \"adding to buffer: %d\" % int(add_to_buffer),\n",
    "              \"erros_obs_euc: %.4f\" % (error_obs_euc_d))\n",
    "\n",
    "        episode_info.append(r.copy())\n",
    "        if done: break\n",
    "\n",
    "    df = pd.DataFrame(episode_info)\n",
    "    # Fraction of executed steps on which the policy re-planned\n",
    "    porc_recalc = (np.sum( df.recalculate ) / len( df.recalculate))\n",
    "    ret = {\n",
    "        \"obs\": np.array(O)[:-1],\n",
    "        \"obs_\": np.array(O)[1:],\n",
    "        \"ac\": np.array(A).reshape(-1,self.policy.dU),\n",
    "        \"reward_sum\": reward_sum,\n",
    "        \"rewards\": np.array(rewards),\n",
    "        \"porc_recalc\":porc_recalc,\n",
    "        \"add_to_buffer\":df[\"add_to_buffer\"].values,\n",
    "        \"recalculated\":df[\"recalculate\"].values\n",
    "    }\n",
    "    print(ret['reward_sum'], porc_recalc)\n",
    "    return ret\n",
    "\n",
    "# Skip-step / buffer settings consumed by the policy.\n",
    "#N_STEPS = cfg.ctrl_cfg.per\n",
    "self.policy.cold_start_steps = None if COLD_START_STEPS is None else int(COLD_START_STEPS)\n",
    "self.policy.max_buffer_length = None if MAX_BUFFER_LENGTH is None else int(MAX_BUFFER_LENGTH)\n",
    "self.policy.new_data_train_threshold = NEW_DATA_TRAIN_THRESHOLD\n",
    "self.policy.method = METHOD\n",
    "# Steps alike\n",
    "S_ALIKE_C = 2.0\n",
    "\n",
    "# State shared with the training-loop cell.\n",
    "verbose = True\n",
    "ep_stats, all_traj = [], []\n",
    "horizon = self.task_hor\n",
    "run_path = output_path\n",
    "\n",
    "if args.LOAD:\n",
    "    # Start from a saved model, so act_dyn plans from the first step instead of acting randomly.\n",
    "    print(\"LOADING\\n\\n\\n\\n\\n\\n\\n\")\n",
    "    self.load_model(args.LOAD)\n",
    "    self.policy.has_been_trained = True\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook overrides that take precedence over the parsed arguments.\n",
    "self.ntrain_iters = 100\n",
    "\n",
    "METHOD = 'BICHO'\n",
    "# sample_P and the training loop read the global METHOD, but the policy was configured with\n",
    "# the parsed value; keep the policy's copy consistent with the override.\n",
    "self.policy.method = METHOD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def pre_process_samples(samples):\n",
    "    \"\"\"Filter every rollout in `samples` down to the steps flagged in `add_to_buffer`.\n",
    "\n",
    "    Used for UARF so that only the selected transitions reach the training buffer. The\n",
    "    `obs`, `obs_`, `ac`, `rewards` and `add_to_buffer` arrays are filtered along their\n",
    "    first axis; the dicts are modified in place and `samples` is returned.\n",
    "    \"\"\"\n",
    "    for sample in samples:\n",
    "        keep = np.asarray(sample[\"add_to_buffer\"]) != 0\n",
    "        for key in (\"obs\", \"obs_\", \"ac\", \"rewards\", \"add_to_buffer\"):\n",
    "            sample[key] = sample[key][keep]\n",
    "    return samples\n",
    "\n",
    "\n",
    "for i in trange(self.ntrain_iters):\n",
    "    samples = []\n",
    "    episode_start_time = time.time()\n",
    "    for j in range(max(self.neval, self.nrollouts_per_iter)):\n",
    "        s = sample_P(self, horizon=horizon, verbose=verbose)\n",
    "        samples.append(s)\n",
    "        all_traj.append(s.copy())\n",
    "\n",
    "    collection_time = time.time() - episode_start_time\n",
    "    # Episode statistics are taken from the first evaluation rollout.\n",
    "    first_eval = samples[:self.neval][0]\n",
    "    ep_reward = first_eval[\"reward_sum\"]\n",
    "    ep_perc_recalc = first_eval[\"porc_recalc\"]\n",
    "    ep_step_rewards = first_eval[\"rewards\"]\n",
    "    ep_recalculated = first_eval[\"recalculated\"]\n",
    "    ep_added_to_buffer = first_eval[\"add_to_buffer\"]\n",
    "\n",
    "    samples = samples[:self.nrollouts_per_iter]\n",
    "\n",
    "    print(\"Start training episode %d\" % i)\n",
    "\n",
    "    # Train\n",
    "    losses = []\n",
    "    train_time = time.time()\n",
    "    # Unfiltered rollouts of this iteration (the file is overwritten every iteration).\n",
    "    filepath = os.path.join(args.logdir, 'replay_buffer_no_filtering.pkl')\n",
    "    utils.to_pickle(obj=samples, file=filepath, verbose=True)\n",
    "\n",
    "    if METHOD == \"UARF\":\n",
    "        samples = pre_process_samples(samples)\n",
    "    l, trained, buffer_last_train = self.policy.train(\n",
    "        [sample[\"obs\"] for sample in samples],\n",
    "        [sample[\"ac\"] for sample in samples],\n",
    "        [sample[\"rewards\"] for sample in samples],\n",
    "        [sample[\"obs_\"] for sample in samples],\n",
    "        [np.full_like(sample[\"rewards\"], maneuver_index) for sample in samples]\n",
    "    )\n",
    "    if trained:\n",
    "        losses.append(l.copy())\n",
    "    buffer_size = self.policy.train_in.shape[0]\n",
    "    if args.SAVE:\n",
    "        self.save(train_iteration = i)\n",
    "    # Read by sample_P: a KL-triggered re-plan is added to the buffer only while fewer than\n",
    "    # MAX_PREDICTION_DISTANCE steps have passed since the previous re-plan.\n",
    "    if METHOD == 'UARF':\n",
    "        MAX_PREDICTION_DISTANCE = int(1 / ep_perc_recalc) - 1\n",
    "    else:\n",
    "        MAX_PREDICTION_DISTANCE = 10e6\n",
    "\n",
    "    train_time = time.time() - train_time\n",
    "    if losses:\n",
    "        losses_mean, losses_std = np.mean(losses), np.std(losses)\n",
    "    else:\n",
    "        # Nothing was trained this iteration; report NaN without numpy's empty-slice warning.\n",
    "        losses_mean = losses_std = np.nan\n",
    "\n",
    "    by_step_stats = {\"Rewards\" : ep_step_rewards,\n",
    "                    \"Recalculated\" : ep_recalculated,\n",
    "                    \"Added to Buffer\" : ep_added_to_buffer}\n",
    "    by_step_df = pd.DataFrame(by_step_stats)\n",
    "    by_step_df.to_csv(os.path.join(output_path_by_episode, f\"episode{i}.csv\"))\n",
    "\n",
    "    ep_stat = {}\n",
    "    # NOTE(review): (i+2) presumably accounts for rollouts collected outside this loop -- confirm.\n",
    "    ep_stat['total_steps'] = (i+2) * len(ep_step_rewards)\n",
    "    ep_stat['ep_reward'] = ep_reward\n",
    "    ep_stat['buffer_size'] = buffer_size\n",
    "    ep_stat['losses_mean'] = losses_mean\n",
    "    ep_stat['losses_std'] = losses_std\n",
    "    ep_stat['perc_recalc'] = ep_perc_recalc\n",
    "    ep_stat['run_wall_time'] = time.time() - episode_start_time\n",
    "    ep_stat['training_time'] = train_time\n",
    "    ep_stat['trained'] = int(trained)\n",
    "    ep_stat['buffer_last_train'] = buffer_last_train\n",
    "    ep_stat['percent_new_experience'] = (buffer_size - ep_stats[-1]['buffer_size'])/buffer_last_train if len(ep_stats) > 0 else 0\n",
    "\n",
    "    ep_stats.append(ep_stat.copy())\n",
    "\n",
    "    print(\"buffer size %d\" %buffer_size , \"Rewards obtained:\", ep_reward, \"loss: %.7f+-%.5f\" % (losses_mean, losses_std),\n",
    "          \"Collection time: %.2fs\" % collection_time, \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ep_perc_recalc)\n",
    "# Per-step diagnostics of the first evaluation rollout of the last training iteration.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(by_step_df.Rewards)\n",
    "ax.set(title='Per-step reward', xlabel='step', ylabel='reward')\n",
    "plt.show()\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(by_step_df.Recalculated)\n",
    "ax.set(title='Re-planning steps', xlabel='step', ylabel='recalculated (0/1)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-step rewards of the first rollout used for training in the last iteration\n",
    "# (after UARF filtering, if that method is active).\n",
    "rewards = samples[0][\"rewards\"]\n",
    "plt.plot(rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: advance the environment by one step with the fixed action [0.01].\n",
    "self.env.step([0.01])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: render the environment.\n",
    "self.env.render()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
