{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import numpy as np\n",
    "\n",
    "from prompt_dt.prompt_decision_transformer import ReachAvoidTransformer, GoalTransformer\n",
    "from prompt_dt.prompt_seq_trainer import PromptSequenceReachAvoidTrainer\n",
    "#from prompt_dt.prompt_utils import get_env_list\n",
    "#from prompt_dt.prompt_utils import get_prompt_batch, get_prompt, get_batch, get_batch_finetune\n",
    "#from prompt_dt.prompt_utils import process_total_data_mean, load_data_prompt, process_info\n",
    "from prompt_dt.prompt_utils import eval_episodes, get_prompt_batch\n",
    "from cosine_annealing_warmup import CosineAnnealingWarmupRestarts\n",
    "from raDT.baselines.envs.maps_obstacle import *\n",
    "\n",
    "import envs.cardiogenesis.utils as cardiogenesis_utils\n",
    "from envs.cardiogenesis.utils import CardiogenesisEnv\n",
    "# import envs.pick_and_place_obstacle.utils as pick_and_place_obstacle_utils\n",
    "\n",
    "from raDT.constants import *\n",
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attractor states, written as binary strings (one character per state variable).\n",
    "attractors = ['000000000000000',\n",
    " '000010010100000',\n",
    " '001111011011111',\n",
    " '010111111010011',\n",
    " '100000000001100',\n",
    " '100010010101100']\n",
    "\n",
    "\n",
    "def parse_state(readable):\n",
    "    \"\"\"Convert a binary state string such as '0101' into an integer numpy array.\n",
    "\n",
    "    Raises AssertionError if the string is not made of '0'/'1' characters or its length\n",
    "    differs from the attractor strings, so a typo fails here instead of inside the env.\n",
    "    \"\"\"\n",
    "    assert len(readable) == len(attractors[0]), f\"state {readable!r} has {len(readable)} bits, expected {len(attractors[0])}\"\n",
    "    assert set(readable) <= {'0', '1'}, f\"state {readable!r} must contain only '0' and '1'\"\n",
    "    return np.array([int(bit) for bit in readable])\n",
    "\n",
    "\n",
    "start_idx = 1 # specify index of attractor state to be initial state of evaluation\n",
    "start_readable = attractors[start_idx]\n",
    "start = parse_state(start_readable)\n",
    "\n",
    "goal_readable = '100010010101100' # specify goal state (this value is attractors[5])\n",
    "goal = parse_state(goal_readable)\n",
    "\n",
    "avoids_readable = [] # specify avoid states as binary strings\n",
    "for avoid_readable in avoids_readable:\n",
    "    parse_state(avoid_readable)  # validate only; the evaluation loop builds the avoid prompt\n",
    "\n",
    "fixed_interval = 10 # NOTE(review): meaning is defined by CardiogenesisEnv (first ctor argument) - confirm\n",
    "env = CardiogenesisEnv(fixed_interval, avoids_readable, goal_readable, start_readable)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model hyperparameters. These must match the checkpoint loaded in the next cell,\n",
    "# otherwise load_state_dict(strict=True) raises on missing/unexpected/mismatched weights.\n",
    "state_dim = 15\n",
    "act_dim = 15\n",
    "prompt_dim = 15\n",
    "max_length = 1\n",
    "max_ep_len = 30  # also the rollout horizon of the evaluation loop\n",
    "hidden_size = 384\n",
    "n_layer = 6\n",
    "n_head = 6\n",
    "discrete_action = True  # evaluation loop keeps integer (torch.long) action buffers when True\n",
    "adelta = 1\n",
    "# There is no on/off flag for avoid prompts: the evaluation loop builds one whenever\n",
    "# avoids_readable is non-empty (a former `avoid_prompt = True` here was never read).\n",
    "model = ReachAvoidTransformer(\n",
    "                        state_dim=state_dim,\n",
    "                        act_dim=act_dim,\n",
    "                        action_space=env.action_space,\n",
    "                        prompt_dim=prompt_dim,\n",
    "                        max_length=max_length,\n",
    "                        max_ep_len=max_ep_len,\n",
    "                        hidden_size=hidden_size,\n",
    "                        n_layer=n_layer,\n",
    "                        n_head=n_head,\n",
    "                        n_inner=4 * hidden_size,\n",
    "                        activation_function='relu',\n",
    "                        n_positions=2048,\n",
    "                        resid_pdrop=0.1,\n",
    "                        attn_pdrop=0.1,\n",
    "                        adelta=adelta  # use the config value instead of a duplicated literal\n",
    "                        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained RADT weights into `model` and switch it to inference mode.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "device = 'cpu'  # defined first so the load and the evaluation loop share one device\n",
    "model_path = \"INSERT/MODEL/PATH/HERE\"  # TODO: path to the trained checkpoint\n",
    "\n",
    "if not os.path.isfile(model_path):\n",
    "    raise FileNotFoundError(f\"checkpoint not found: {model_path!r}; set model_path to the trained RADT checkpoint\")\n",
    "\n",
    "state_dict = torch.load(model_path, map_location=torch.device(device))\n",
    "model.load_state_dict(state_dict, strict=True)\n",
    "model.to(device=device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the loaded model in `env` for num_eval_episodes episodes and report the mean\n",
    "# episode length, success rate, cost and return.\n",
    "# set num_eval_episodes to the number of rollouts to evaluate on\n",
    "num_eval_episodes = 200\n",
    "\n",
    "successes = []\n",
    "costs = []\n",
    "returns = []\n",
    "ep_lens = []\n",
    "\n",
    "trajs_readable = []         # per episode: visited states as binary strings\n",
    "action_trajs_readable = []  # per episode: actions returned by the model, as numpy arrays\n",
    "\n",
    "# half-width of the [lower, upper] bounds placed around each avoid state in the avoid prompt\n",
    "buffer_size = 0.001\n",
    "\n",
    "# The avoid prompt depends only on avoids_readable, so it is built once before the rollouts.\n",
    "# Every listed avoid state contributes one row [state - buffer_size, state + buffer_size].\n",
    "if avoids_readable:\n",
    "    avoids = np.array([[int(bit) for bit in readable] for readable in avoids_readable])\n",
    "    avoid_states = np.concatenate([avoids - buffer_size, avoids + buffer_size], axis=1)\n",
    "    avoid_prompt_state = torch.from_numpy(avoid_states.reshape(1, -1, prompt_dim * 2)).to(dtype=torch.float32, device=device)\n",
    "    avoid_prompt_mask = torch.ones(avoid_prompt_state.shape[:2]).to(device=device)\n",
    "    avoid_prompt = (avoid_prompt_state, avoid_prompt_mask)\n",
    "else:\n",
    "    avoid_prompt = None\n",
    "\n",
    "for i in range(num_eval_episodes):\n",
    "\n",
    "    states_readable = []\n",
    "    actions_readable = []\n",
    "\n",
    "    obs, info = env.reset()\n",
    "    state = obs[\"observation\"]\n",
    "    states_readable.append(\"\".join([str(x) for x in state]))\n",
    "\n",
    "    # customized: create a prompt to be whatever the environment goal is\n",
    "    episode_goal = obs[\"desired_goal\"]  # separate name so the configured `goal` is not overwritten\n",
    "    prompt_goal = torch.from_numpy(np.array([episode_goal]).reshape(1, -1, prompt_dim)).to(dtype=torch.float32, device=device)\n",
    "    prompt_mask = torch.ones(prompt_goal.shape[:2]).to(device=device)\n",
    "    prompt = (prompt_goal, prompt_mask)\n",
    "\n",
    "    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)\n",
    "    if discrete_action:\n",
    "        actions = torch.zeros((0,), device=device, dtype=torch.long)\n",
    "    else:\n",
    "        actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)\n",
    "    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)\n",
    "\n",
    "    episode_return, episode_cost_return, episode_length = 0, 0, 0\n",
    "    infos = {\"is_success\": False}  # fallback so a zero-step episode never reports the previous episode's result\n",
    "    for t in range(max_ep_len):\n",
    "        # append a placeholder slot for the action about to be predicted; it is filled in below\n",
    "        if discrete_action:\n",
    "            actions = torch.cat([actions, torch.zeros((1,), device=device, dtype=torch.long)], dim=0)\n",
    "        else:\n",
    "            actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)\n",
    "        action = model.get_action(\n",
    "            states,\n",
    "            actions,\n",
    "            timesteps,\n",
    "            prompt=prompt,\n",
    "            avoid_prompt=avoid_prompt,\n",
    "            success_list=torch.tensor([True], device=device)\n",
    "        )\n",
    "\n",
    "        actions[-1] = action\n",
    "        action = action.detach().cpu().numpy()\n",
    "        actions_readable.append(action)\n",
    "\n",
    "        obs, reward, cost, done, infos = env.step(action)\n",
    "        state = obs[\"observation\"]\n",
    "        states_readable.append(\"\".join([str(x) for x in state]))\n",
    "\n",
    "        # keep the state history float32, matching the initial `states` tensor\n",
    "        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)\n",
    "        states = torch.cat([states, cur_state], dim=0)\n",
    "        timesteps = torch.cat(\n",
    "            [timesteps,\n",
    "             torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1)\n",
    "\n",
    "        episode_return += reward\n",
    "        episode_cost_return += cost\n",
    "        episode_length += 1\n",
    "        infos['episode_length'] = episode_length\n",
    "\n",
    "        if done or infos['is_success']: # stop when success is reached\n",
    "            break\n",
    "\n",
    "    trajs_readable.append(states_readable)\n",
    "    action_trajs_readable.append(actions_readable)\n",
    "    print(f\"Episode {i} normalized cost return:\", episode_cost_return)\n",
    "    print(f\"Episode {i} success:\", infos[\"is_success\"])\n",
    "    successes.append(infos[\"is_success\"])\n",
    "    costs.append(episode_cost_return)\n",
    "    returns.append(episode_return)\n",
    "    ep_lens.append(episode_length)\n",
    "\n",
    "print(\"Mean length: \", np.mean(ep_lens))\n",
    "print(\"Mean success: \", np.mean(successes))\n",
    "print(\"Mean cost: \", np.mean(costs))\n",
    "print(\"Mean return: \", np.mean(returns))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
