{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-12-27T15:08:00.518601Z",
     "start_time": "2023-12-27T15:07:58.184122Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "pybullet build time: May 20 2022 19:45:31\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "\n",
    "import gym\n",
    "import d4rl\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use('fivethirtyeight')\n",
    "\n",
    "from matplotlib.colors import LinearSegmentedColormap, ListedColormap\n",
    "from matplotlib import patches\n",
    "\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import functools\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from jaxrl_m.common import TrainStateEQX\n",
    "from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy\n",
    "\n",
    "from ott.geometry import pointcloud\n",
    "from ott.problems.linear import linear_problem\n",
    "from ott.solvers.linear import sinkhorn\n",
    "from ott.tools import plot, sinkhorn_divergence\n",
    "from ott.solvers.linear import implicit_differentiation as imp_diff\n",
    "\n",
    "import optax\n",
    "\n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None})\n",
    "def eval_ensemble_psi(ensemble, s):\n",
    "    '''\n",
    "    Evaluate `psi_net` of every ensemble member on every row of `s`.\n",
    "\n",
    "    The outer vmap runs over axis 0 of the array leaves of `ensemble` and\n",
    "    shares `s`; the inner vmap runs `psi_net` over axis 0 of `s`.\n",
    "    '''\n",
    "    psi_over_states = eqx.filter_vmap(ensemble.psi_net)\n",
    "    return psi_over_states(s)\n",
    "\n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None})\n",
    "def eval_ensemble_phi(ensemble, s):\n",
    "    '''\n",
    "    Evaluate `phi_net` of every ensemble member on every row of `s`.\n",
    "\n",
    "    Same mapping scheme as eval_ensemble_psi: ensemble array leaves are\n",
    "    mapped over axis 0, `s` is shared and mapped row-wise inside.\n",
    "    '''\n",
    "    phi_over_states = eqx.filter_vmap(ensemble.phi_net)\n",
    "    return phi_over_states(s)\n",
    "\n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None, 'g': None, 'z': None})\n",
    "def eval_ensemble_icvf_viz(ensemble, s, g, z):\n",
    "    '''\n",
    "    Evaluate `classic_icvf_initial` of every ensemble member on row-aligned (s, g, z).\n",
    "\n",
    "    `s`, `g` and `z` are shared across ensemble members and mapped\n",
    "    together over their axis 0 by the inner vmap.\n",
    "    '''\n",
    "    value_over_rows = eqx.filter_vmap(ensemble.classic_icvf_initial)\n",
    "    return value_over_rows(s, g, z)\n",
    "\n",
    "# Author's note: V(s, g, z), g - dim 29, z - dim 256\n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None, 'g': None, 'z': None})\n",
    "def eval_ensemble_icvf_latent_z(ensemble, s, g, z):\n",
    "    '''\n",
    "    Evaluate `classic_icvf` of every ensemble member on row-aligned (s, g, z).\n",
    "\n",
    "    Inputs are shared across ensemble members; the inner vmap maps the\n",
    "    three of them together over axis 0.\n",
    "    '''\n",
    "    value_over_rows = eqx.filter_vmap(ensemble.classic_icvf)\n",
    "    return value_over_rows(s, g, z)\n",
    "\n",
    "# Author's note: V(s, g ,z ), g, z - dim 256\n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None, 'g': None, 'z': None})\n",
    "def eval_ensemble_icvf_latent_zz(ensemble, s, g, z):\n",
    "    '''\n",
    "    Evaluate `icvf_zz` of every ensemble member on row-aligned (s, g, z).\n",
    "\n",
    "    Inputs are shared across ensemble members; the inner vmap maps the\n",
    "    three of them together over axis 0.\n",
    "    '''\n",
    "    value_over_rows = eqx.filter_vmap(ensemble.icvf_zz)\n",
    "    return value_over_rows(s, g, z)\n",
    "    \n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None, 'g': None, 'z': None})\n",
    "def eval_ensemble_icvf_latent_zzz(ensemble, s, g, z):\n",
    "    '''\n",
    "    Evaluate `icvf_zzz` of every ensemble member on row-aligned (s, g, z).\n",
    "\n",
    "    Inputs are shared across ensemble members; the inner vmap maps the\n",
    "    three of them together over axis 0.\n",
    "    '''\n",
    "    value_over_rows = eqx.filter_vmap(ensemble.icvf_zzz)\n",
    "    return value_over_rows(s, g, z)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0, z=None))\n",
    "def eqx_get_state_traj(ensemble, s, z):\n",
    "    '''\n",
    "    Evaluate `icvf_zzz` for each state in `s` against every row of `z`.\n",
    "\n",
    "    The vmap runs over axis 0 of `s`. Inside, the single state is tiled to\n",
    "    z.shape[0] rows and `z` is passed as both the second and the third\n",
    "    argument of eval_ensemble_icvf_latent_zzz.\n",
    "\n",
    "    Args:\n",
    "        ensemble: value ensemble to evaluate (shared across the vmap),\n",
    "            e.g. icvf_model.value_learner.model.\n",
    "        s: batch of states, mapped over axis 0.\n",
    "        z: 2-D array shared by every state.\n",
    "\n",
    "    Returns:\n",
    "        The stacked outputs of eval_ensemble_icvf_latent_zzz, one per state.\n",
    "    '''\n",
    "    s = jnp.tile(s, (z.shape[0], 1))\n",
    "    # Use the ensemble that was passed in. The previous version ignored the\n",
    "    # argument and always read the notebook-level `icvf_model`, which is only\n",
    "    # defined in a later cell.\n",
    "    return eval_ensemble_icvf_latent_zzz(ensemble, s, z, z)\n",
    "\n",
    "@eqx.filter_jit\n",
    "def get_gcvalue(agent, s, g, z):\n",
    "    '''\n",
    "    Average of the two ensemble outputs of eval_ensemble_icvf_viz.\n",
    "\n",
    "    The result of eval_ensemble_icvf_viz on agent.value_learner.model is\n",
    "    unpacked along its leading axis into exactly two members.\n",
    "    '''\n",
    "    ensemble_values = eval_ensemble_icvf_viz(agent.value_learner.model, s, g, z)\n",
    "    first_member, second_member = ensemble_values\n",
    "    return (first_member + second_member) / 2\n",
    "\n",
    "def get_v_gz(agent, initial_state, target_goal, observations):\n",
    "    '''\n",
    "    Negated get_gcvalue with `initial_state` as s, every observation as g\n",
    "    and `target_goal` as z.\n",
    "\n",
    "    `initial_state` and `target_goal` are tiled to one row per observation.\n",
    "    '''\n",
    "    n_obs = observations.shape[0]\n",
    "    tiled_state = jnp.tile(initial_state, (n_obs, 1))\n",
    "    tiled_goal = jnp.tile(target_goal, (n_obs, 1))\n",
    "    return -1 * get_gcvalue(agent, tiled_state, observations, tiled_goal)\n",
    "    \n",
    "def get_v_zz(agent, goal, observations):\n",
    "    '''\n",
    "    get_gcvalue of every observation with `goal` used as both g and z.\n",
    "\n",
    "    `goal` is tiled to one row per observation before the call.\n",
    "    '''\n",
    "    tiled_goal = jnp.tile(goal, (observations.shape[0], 1))\n",
    "    return get_gcvalue(agent, observations, tiled_goal, tiled_goal)\n",
    "\n",
    "@eqx.filter_vmap(in_axes={'agent': None, 'obs': None, 'goal': 0})\n",
    "def get_v_zz_heatmap(agent, obs, goal):\n",
    "    '''\n",
    "    get_gcvalue of every row of `obs` for each goal of a trajectory.\n",
    "\n",
    "    The vmap runs over axis 0 of `goal` (the trajectory); `agent` and `obs`\n",
    "    are shared. Each goal is tiled to one row per observation and used as\n",
    "    both g and z.\n",
    "    '''\n",
    "    tiled_goal = jnp.tile(goal, (obs.shape[0], 1))\n",
    "    return get_gcvalue(agent, obs, tiled_goal, tiled_goal)\n",
    "\n",
    "@eqx.filter_vmap(in_axes={'ensemble': None, 's': 0})\n",
    "@eqx.filter_vmap(in_axes={'ensemble': eqx.if_array(0), 's': None})\n",
    "def batched_eval_ensemble_psi(ensemble, s):\n",
    "    '''\n",
    "    Evaluate `psi_net` of every ensemble member on a batch of state batches.\n",
    "\n",
    "    Outermost vmap: axis 0 of `s` with the ensemble shared. Middle vmap:\n",
    "    axis 0 of the ensemble array leaves. Innermost vmap: rows of `s`.\n",
    "    '''\n",
    "    psi_over_states = eqx.filter_vmap(ensemble.psi_net)\n",
    "    return psi_over_states(s)\n",
    "\n",
    "jnp.set_printoptions(precision=2, suppress=True)\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# (scale, offset) pairs consumed by scale_and_shift: entry 0 is applied to x, entry 1 to y.\n",
    "# Presumably maps maze coordinates onto background-image pixels - TODO confirm.\n",
    "scales_shifts = [(6, 40), (-6, 230)]\n",
    "\n",
    "def generate_maze_img(ax, fig, n=50, icvf_values='icvf', intents_learner=None, state_list=None, traj=None, rng_key=None):\n",
    "    '''\n",
    "    Draw an n x n heatmap over the maze and overlay the maze walls.\n",
    "\n",
    "    Args:\n",
    "        ax: matplotlib axes to draw on.\n",
    "        fig: figure that owns `ax`; receives the colorbar.\n",
    "        n: number of grid points per axis.\n",
    "        icvf_values: what the grid is coloured with:\n",
    "            None         - uniform random values,\n",
    "            'icvf'       - get_v_zz_heatmap(icvf_model, grid, traj) averaged over axis 0,\n",
    "            'likelihood' - log-prob of one sample drawn from `intents_learner` at each grid state.\n",
    "        intents_learner: callable applied per grid state; required for 'likelihood'.\n",
    "        state_list: unused, kept for backward compatibility.\n",
    "        traj: array of states. Its last row is the template state whose first\n",
    "            two entries are overwritten with the grid coordinates. Required.\n",
    "        rng_key: PRNG key for the 'likelihood' mode. When None the\n",
    "            notebook-level `key` is used, as before.\n",
    "\n",
    "    Returns:\n",
    "        The axes `ax`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `traj` is None or the requested mode is not supported.\n",
    "    '''\n",
    "    if traj is None:\n",
    "        raise ValueError('generate_maze_img needs `traj`: its last state is the template for the grid states')\n",
    "\n",
    "    maze_env = env.env.env.wrapped_env\n",
    "    maze_map = maze_env._maze_map\n",
    "    torso_x, torso_y = maze_env._init_torso_x, maze_env._init_torso_y\n",
    "    S = maze_env._maze_size_scaling\n",
    "\n",
    "    points = XY(env, n=n)\n",
    "    grid_x = points[:, 0].reshape(n, n)\n",
    "    grid_y = points[:, 1].reshape(n, n)\n",
    "    # np.tile returns a fresh array, so writing the xy columns never touches `traj`.\n",
    "    whole_grid = np.tile(traj[-1], (points.shape[0], 1))\n",
    "    whole_grid[:, :2] = points\n",
    "\n",
    "    if icvf_values is None:\n",
    "        Z = np.random.rand(n, n)\n",
    "        im = ax.pcolormesh(grid_x, grid_y, Z.reshape(n, n), edgecolor='black')\n",
    "    elif icvf_values == 'icvf':\n",
    "        Z = get_v_zz_heatmap(icvf_model, whole_grid, traj).mean(0)\n",
    "        im = ax.pcolormesh(grid_x, grid_y, Z.reshape(n, n), cmap='RdBu_r', edgecolor='black')\n",
    "    elif icvf_values == 'likelihood' and intents_learner is not None:\n",
    "        points_dist = eqx.filter_vmap(intents_learner)(whole_grid)\n",
    "        likelihood = points_dist.sample_and_log_prob(seed=key if rng_key is None else rng_key)[1]\n",
    "        im = ax.pcolormesh(grid_x, grid_y, likelihood.reshape(n, n), cmap='RdBu_r', edgecolor='black')\n",
    "    else:\n",
    "        # Previously this fell through and failed later at fig.colorbar(im) with an unbound `im`.\n",
    "        raise ValueError(f'unsupported icvf_values={icvf_values!r} (the likelihood mode also needs intents_learner)')\n",
    "\n",
    "    # Cells equal to 1 in the maze map are drawn as wall rectangles of side S.\n",
    "    for i, row in enumerate(maze_map):\n",
    "        for j, struct in enumerate(row):\n",
    "            if struct == 1:\n",
    "                rect = patches.Rectangle((j * S - torso_x - S / 2,\n",
    "                                          i * S - torso_y - S / 2),\n",
    "                                         S,\n",
    "                                         S, linewidth=1, facecolor='RosyBrown', alpha=1.0)\n",
    "                ax.add_patch(rect)\n",
    "    ax.set_xlim(0 - S / 2 + 0.6 * S - torso_x, len(maze_map[0]) * S - torso_x - S / 2 - S * 0.6)\n",
    "    ax.set_ylim(0 - S / 2 + 0.6 * S - torso_y, len(maze_map) * S - torso_y - S / 2 - S * 0.6)\n",
    "    fig.colorbar(im)\n",
    "    return ax\n",
    "\n",
    "def get_starting_boundary(env):\n",
    "    '''\n",
    "    Return ((x_min, y_min), (x_max, y_max)) of the maze with a one-cell border removed.\n",
    "\n",
    "    Coordinates are derived from the wrapped maze env: cell size\n",
    "    `_maze_size_scaling`, map `_maze_map`, origin offset `_init_torso_x/y`.\n",
    "    '''\n",
    "    maze = env.env.env.wrapped_env\n",
    "    x0, y0 = maze._init_torso_x, maze._init_torso_y\n",
    "    grid = maze._maze_map\n",
    "    cell = maze._maze_size_scaling\n",
    "    bottom_left = (0 - cell / 2 + cell - x0, 0 - cell / 2 + cell - y0)\n",
    "    top_right = (len(grid[0]) * cell - x0 - cell / 2 - cell, len(grid) * cell - y0 - cell / 2 - cell)\n",
    "    return bottom_left, top_right\n",
    "\n",
    "def XY(env, n=10):\n",
    "    '''\n",
    "    Build an n x n grid of (x, y) points inside the maze boundary.\n",
    "\n",
    "    The grid is inset by 4% of the extent on every side of the box returned\n",
    "    by get_starting_boundary. Returns an array of shape (n * n, 2).\n",
    "    '''\n",
    "    (x_lo, y_lo), (x_hi, y_hi) = get_starting_boundary(env)\n",
    "    xs = np.linspace(x_lo + 0.04 * (x_hi - x_lo), x_hi - 0.04 * (x_hi - x_lo), n)\n",
    "    ys = np.linspace(y_lo + 0.04 * (y_hi - y_lo), y_hi - 0.04 * (y_hi - y_lo), n)\n",
    "\n",
    "    mesh_x, mesh_y = np.meshgrid(xs, ys)\n",
    "    return np.array([mesh_x.flatten(), mesh_y.flatten()]).T\n",
    "\n",
    "def scale_and_shift(x, lst):\n",
    "    '''Affine map of `x`: lst[0] is the scale, lst[1] the offset.'''\n",
    "    scale, shift = lst[0], lst[1]\n",
    "    return scale * x + shift\n",
    "    \n",
    "def plot_traj_image(traj, start, goal, intents, expert_trajectory, bgpath):\n",
    "    '''\n",
    "    Scatter the expert trajectory, start, goal and intents on top of a background image.\n",
    "\n",
    "    Every coordinate goes through scale_and_shift with scales_shifts[0] for x\n",
    "    and scales_shifts[1] for y. `traj` is accepted but not used.\n",
    "    '''\n",
    "    x_map, y_map = scales_shifts[0], scales_shifts[1]\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    ax.imshow(plt.imread(bgpath))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    layers = [\n",
    "        (expert_trajectory[:, 0], expert_trajectory[:, 1], dict(alpha=1, label='trajectory', color='orange')),\n",
    "        (start[0], start[1], dict(c='g', s=100, label='start')),\n",
    "        (goal[0], goal[1], dict(c='r', s=100, label='goal')),\n",
    "        (intents[0], intents[1], dict(c='blue', s=100, label='intents')),\n",
    "    ]\n",
    "    for xs, ys, style in layers:\n",
    "        ax.scatter(scale_and_shift(xs, x_map), scale_and_shift(ys, y_map), **style)\n",
    "    ax.legend(fontsize=10)\n",
    "\n",
    "def icvf_stat(batch):\n",
    "    '''\n",
    "    Print and return ICVF value diagnostics for one batch.\n",
    "\n",
    "    Reads batch['observations'] and batch['ailot_high_targets'] and evaluates\n",
    "    get_v_zz with the notebook-level `icvf_model`.\n",
    "\n",
    "    Returns:\n",
    "        (advantages, value) where\n",
    "        advantages = V(sg, g, g) - V(s, g, g) and\n",
    "        value = V(sg, s, s) - V(g, s, s), in the argument order of get_v_zz.\n",
    "    '''\n",
    "    obs = batch['observations']\n",
    "    # NOTE(review): subgoal and goal are read from the same key, so every\n",
    "    # subgoal-vs-goal difference below compares a tensor with itself - confirm\n",
    "    # whether `goal` should come from a different batch entry.\n",
    "    subgoal = batch['ailot_high_targets']\n",
    "    goal = batch['ailot_high_targets']\n",
    "\n",
    "    # Each value is computed once and reused for both the result and the printout.\n",
    "    v_sg_g = get_v_zz(icvf_model, goal, subgoal)\n",
    "    v_s_g = get_v_zz(icvf_model, goal, obs)\n",
    "    v_s_sg = get_v_zz(icvf_model, subgoal, obs)\n",
    "    v_sg_s = get_v_zz(icvf_model, obs, subgoal)\n",
    "\n",
    "    advantages = v_sg_g - v_s_g\n",
    "    value = v_sg_s - get_v_zz(icvf_model, obs, goal)\n",
    "\n",
    "    ensemble = icvf_model.value_learner.model\n",
    "    sg_repr_dist = np.linalg.norm(eval_ensemble_psi(ensemble, obs) - eval_ensemble_psi(ensemble, subgoal))\n",
    "    print(sg_repr_dist)\n",
    "\n",
    "    print('V(sg, g, g)=', v_sg_g)\n",
    "    print('V(s, g, g)=', v_s_g)\n",
    "    print('V(s, sg, sg)=', v_s_sg)\n",
    "    print('V(sg, s, s)=', v_sg_s)\n",
    "    print('Advantage of going to sg from s: ', advantages)\n",
    "\n",
    "    return advantages, value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-27T15:07:58.310Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile: 100%|██████████████████████████████████████████████████████████████████████████████████| 9/9 [00:04<00:00,  1.82it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd04d5719a9a4d5eae335d80da0a60ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1998000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expert returns [11086.897860401426, 11124.63724167645, 11092.346300512552, 11188.732963606715, 11239.283746674657], mean 11146.37962257436\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile: 100%|██████████████████████████████████████████████████████████████████████████████████| 9/9 [00:04<00:00,  1.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of terminal states: 1\n",
      "Number of terminal states: 2001\n"
     ]
    }
   ],
   "source": [
    "from src.gc_dataset import GCSDataset\n",
    "from utils.ds_builder import setup_datasets\n",
    "\n",
    "# Build the environment plus the expert and agent datasets; both come from\n",
    "# halfcheetah-medium-expert-v2, with expert_num=5 and agent states left unnormalised.\n",
    "env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(expert_env_name=\"halfcheetah-medium-expert-v2\",\n",
    "                                          agent_env_name=\"halfcheetah-medium-expert-v2\", expert_num=5,\n",
    "                                          normalize_agent_states=False)\n",
    "\n",
    "# Wrap both datasets in GCSDataset using its default configuration.\n",
    "gcsds_params = GCSDataset.get_default_config()\n",
    "gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)\n",
    "gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)\n",
    "\n",
    "# All expert observations as one array; later cells index it as a single trajectory.\n",
    "expert_trajectory = gc_expert_dataset.dataset.dataset_dict['observations']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/AILOT\n",
      "Extra kwargs: {}\n"
     ]
    }
   ],
   "source": [
    "from src.agents import icvf\n",
    "# NOTE(review): `%cd ..` is not idempotent - every re-run of this cell moves one more\n",
    "# directory up. Presumably needed so `pretrained_folder` resolves from the repo root - confirm.\n",
    "%cd ..\n",
    "# Create the ICVF learner from the first agent observation and load pretrained weights.\n",
    "icvf_model = icvf.create_eqx_learner(seed=42,\n",
    "                                     observations=gc_agent_dataset.dataset.dataset_dict['observations'][0],\n",
    "                                     hidden_dims=[256, 256],\n",
    "                                     pretrained_folder=\"halfcheetah-medium-expert\",\n",
    "                                     load_pretrained_icvf=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICVF value heatmap with the expert observations overlaid (first two observation dims as x/y).\n",
    "# NOTE: generate_maze_img reads maze attributes from `env`, so this cell needs a maze environment.\n",
    "%matplotlib inline\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 5))\n",
    "\n",
    "generate_maze_img(ax, fig, traj=expert_trajectory, icvf_values='icvf')\n",
    "\n",
    "# Draw on `ax` explicitly and give every series its own legend entry\n",
    "# (all three used to be labelled 'trajectory').\n",
    "ax.scatter(expert_trajectory[:, 0], expert_trajectory[:, 1], alpha=1, label='trajectory', color='orange')\n",
    "ax.scatter(expert_trajectory[150, 0], expert_trajectory[150, 1], alpha=1, label='state 150', color='blue')\n",
    "ax.scatter(expert_trajectory[350, 0], expert_trajectory[350, 1], alpha=1, label='state 350', color='black')\n",
    "ax.legend(loc='upper left', fontsize=10)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# IQL (To check if IQL works on full D4RL dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# `wandb` is used from here on but is not imported in the setup cell, so a fresh\n",
    "# kernel would fail with NameError without this import.\n",
    "import wandb\n",
    "\n",
    "# Start the W&B run for the IQL baseline; the config mirrors the Learner arguments of the next cell.\n",
    "wandb.init(project='D4RL-jupyter', group='IQL-base', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 329399, 'expectile':0.9, 'discount': 0.999, 'temperature': 6})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "# Total number of gradient steps; also passed to the Learner.\n",
    "max_steps = 1_000_000\n",
    "\n",
    "# IQL learner, seeded with 329399 and initialised from one sampled observation/action\n",
    "# with a leading batch axis added.\n",
    "iql_agent = Learner(\n",
    "        329399,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=max_steps,\n",
    "        expectile=0.9,\n",
    "        discount=0.999,\n",
    "        temperature=6)\n",
    "\n",
    "pbar = tqdm(range(max_steps))\n",
    "for i in pbar:\n",
    "    # One batch of 256 transitions from the agent dataset, using the dataset rewards.\n",
    "    sample = gc_agent_dataset.dataset.sample(batch_size=256)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample[\"actions\"],\n",
    "        rewards = sample[\"rewards\"], \n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent.update(batch)\n",
    "    # The 'adv' entry is blanked before logging; presumably it is a per-sample array - TODO confirm.\n",
    "    update_info['adv'] = None\n",
    "    # Evaluate on 10 episodes every 50k steps (skipping step 0).\n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent, env, num_episodes=10)\n",
    "        wandb.log({'Eval': eval_stats})\n",
    "        print(eval_stats)\n",
    "        pbar.set_postfix(update_info)\n",
    "    # Log training metrics every 3000 steps.\n",
    "    if i % 3000 == 0:\n",
    "        wandb.log({'Training/': update_info})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Final evaluation of the IQL baseline over 100 episodes, logged under 'Final Eval/<stat>'.\n",
    "eval_stats = evaluate(iql_agent, env, num_episodes=100)\n",
    "wandb.log({f\"Final Eval/{k}\": stat for k, stat in eval_stats.items()})\n",
    "print(eval_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the W&B run opened for the IQL baseline.\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "# (scale, offset) pairs: entry 0 is applied to x, entry 1 to y.\n",
    "scales_shifts = [(6, 40), (-6, 230)]\n",
    "def scale_and_shift(x, lst):\n",
    "    '''Affine map of `x`: lst[0] is the scale, lst[1] the offset.'''\n",
    "    scale, shift = lst[0], lst[1]\n",
    "    return scale * x + shift\n",
    "    \n",
    "def plot_traj_image(traj, start, goal, subgoals, bgpath):\n",
    "    '''\n",
    "    Scatter the expert trajectory, start, goal and sub-goal on top of a background image.\n",
    "\n",
    "    The trajectory is read from the notebook-level `expert_trajectory`; the\n",
    "    `traj` argument is accepted but not used. Coordinates go through\n",
    "    scale_and_shift with scales_shifts[0] for x and scales_shifts[1] for y.\n",
    "    '''\n",
    "    x_map, y_map = scales_shifts[0], scales_shifts[1]\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    ax.imshow(plt.imread(bgpath))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    layers = [\n",
    "        (expert_trajectory[:, 0], expert_trajectory[:, 1], dict(alpha=1, label='trajectory', color='orange')),\n",
    "        (start[0], start[1], dict(c='g', s=100, label='start')),\n",
    "        (goal[0], goal[1], dict(c='r', s=100, label='goal')),\n",
    "        (subgoals[0], subgoals[1], dict(c='b', s=100, label='S_{t+k}')),\n",
    "    ]\n",
    "    for xs, ys, style in layers:\n",
    "        ax.scatter(scale_and_shift(xs, x_map), scale_and_shift(ys, y_map), **style)\n",
    "    ax.legend(fontsize=10)\n",
    "\n",
    "# Roll out the trained IQL agent from the first expert state towards the last one,\n",
    "# collecting rendered frames for a video.\n",
    "sample = gc_agent_dataset.dataset.sample(1)\n",
    "start_point = expert_trajectory[0]\n",
    "target_goal = expert_trajectory[-1]\n",
    "sample_key = jax.random.PRNGKey(42)\n",
    "\n",
    "env.reset()\n",
    "# NOTE(review): set_xy / set_target are called on the wrapped env; the dataset cell builds a\n",
    "# halfcheetah env, so these presumably only exist for maze environments - confirm.\n",
    "env.env.env.wrapped_env.set_xy((start_point[0], start_point[1]))\n",
    "env.env.env.wrapped_env.set_target((target_goal[0], target_goal[1]))\n",
    "start_point = env.env.env.wrapped_env._get_obs()\n",
    "curr_point = start_point\n",
    "frames=[]\n",
    "\n",
    "i = 0\n",
    "done = False\n",
    "while not done:\n",
    "    # `key` is not used inside this loop, but generate_maze_img reads a notebook-level `key`.\n",
    "    key, sample_key = jax.random.split(sample_key, 2)\n",
    "    # temperature=0.0 is passed to sample_actions; presumably this selects deterministic actions - TODO confirm.\n",
    "    action = jax.device_get(iql_agent.sample_actions(curr_point.squeeze(), temperature=0.0))\n",
    "    new_obs, reward, done ,_ = env.step(action)\n",
    "    \n",
    "    # NOTE(review): CUDA_VISIBLE_DEVICES is switched around the render call and then left at\n",
    "    # '1,2,3' (the setup cell sets '1'); whether the switch has any effect at this point is unclear.\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "    frames.append(env.render(mode='rgb_array'))\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='1,2,3'\n",
    "    if done:\n",
    "        print(reward)  \n",
    "    # Plot the current position every 100 steps; new_obs is passed as both `start` and `subgoals`.\n",
    "    # NOTE(review): the background image path is an absolute path on one machine.\n",
    "    if i % 100 == 0:\n",
    "        plot_traj_image(sample, new_obs, target_goal, new_obs, \"/home/m_bobrin/AILOT/notebooks/antmaze-large.png\")\n",
    "\n",
    "    plt.show()\n",
    "    curr_point = new_obs\n",
    "    i+=1\n",
    "# Write the collected frames as a video into the current directory.\n",
    "save_video.save_video(frames, video_folder='.', fps=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# VAE test (with actions from expert & agent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the W&B run for the IQL + VAE-reward experiment.\n",
    "# NOTE(review): the values logged here (expectile 0.95, discount 0.99, temperature 3, latent_dim 16)\n",
    "# differ from the ones used by the visible training cells (Learner(expectile=0.9, discount=0.999,\n",
    "# temperature=6) and VAEcfg.latent_dim = action_dim * 2) - confirm which set is intended.\n",
    "wandb.init(project='D4RL-jupyter', group='IQL-VAE-rewards', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 1337, 'expectile':0.95, 'discount': 0.99, 'temperature': 3,\n",
    "                 'latent_dim': 16, 'hidden_dim': 512, 'lr': 1e-4})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jaxtyping import Key\n",
    "\n",
    "class VAE(eqx.Module):\n",
    "    '''\n",
    "    Conditional VAE over actions given states.\n",
    "\n",
    "    The encoder maps concat(state, action) to the mean and log-std of a\n",
    "    diagonal Gaussian over the latent. The decoder maps concat(state, latent)\n",
    "    to an action squashed by tanh. Both MLPs have depth 2 and width\n",
    "    `hidden_dim`. Methods act on one unbatched sample; batch them with\n",
    "    eqx.filter_vmap.\n",
    "    '''\n",
    "    encoder: eqx.Module\n",
    "    decoder: eqx.Module\n",
    "    key: Key  # fallback PRNG key, only used by decode() when no latent is given\n",
    "    latent_dim: int\n",
    "\n",
    "    def __init__(self, state_dim, action_dim, latent_dim, hidden_dim, key):\n",
    "        # Split five ways (the last two keys are unused) so the encoder/decoder\n",
    "        # initialisation for a given seed is the same as before.\n",
    "        key, encoder_key, decoder_key, _, _ = jax.random.split(key, 5)\n",
    "\n",
    "        self.latent_dim = latent_dim\n",
    "        self.key = key\n",
    "\n",
    "        # The encoder output is latent_dim * 2 wide: first half mean, second half log-std.\n",
    "        self.encoder = eqx.nn.MLP(in_size=state_dim + action_dim, out_size=latent_dim * 2,\n",
    "                                  width_size=hidden_dim, depth=2, key=encoder_key)\n",
    "        self.decoder = eqx.nn.MLP(in_size=latent_dim + state_dim, out_size=action_dim,\n",
    "                                  width_size=hidden_dim, depth=2, key=decoder_key)\n",
    "\n",
    "    def __call__(self, state, action, random_key):\n",
    "        '''\n",
    "        Encode (state, action), sample a latent and decode it.\n",
    "\n",
    "        Returns:\n",
    "            (mean, std, decoded): latent mean, latent standard deviation\n",
    "            (exp of the clipped log-std) and the reconstructed action.\n",
    "        '''\n",
    "        mean, log_std = jnp.split(self.encoder(jnp.concatenate([state, action], axis=-1)), 2, axis=-1)\n",
    "        # Bounds are passed positionally: the a_min/a_max keywords are deprecated in recent JAX.\n",
    "        std = jnp.exp(jnp.clip(log_std, -4, 15))\n",
    "\n",
    "        # Reparameterisation trick: z = mean + std * eps with eps ~ N(0, I).\n",
    "        latent_vec = mean + std * jax.random.normal(key=random_key, shape=std.shape)\n",
    "        decoded = self.decode(state, latent_vec)\n",
    "\n",
    "        return mean, std, decoded\n",
    "\n",
    "    def decode(self, state, latent=None):\n",
    "        '''\n",
    "        Decode an action from `state` and `latent`.\n",
    "\n",
    "        When `latent` is None a standard-normal latent is drawn from a key\n",
    "        derived from `self.key`. eqx.Module instances are immutable, so the\n",
    "        stored key cannot be advanced: repeated calls without a latent\n",
    "        return the same draw. Pass an explicit latent for fresh samples.\n",
    "        '''\n",
    "        if latent is None:\n",
    "            # The old code sampled into an unused variable, concatenated `None`\n",
    "            # and tried to assign to self.key on a frozen module.\n",
    "            _, latent_key = jax.random.split(self.key, 2)\n",
    "            latent = jax.random.normal(key=latent_key, shape=state.shape[:-1] + (self.latent_dim,))\n",
    "        a = self.decoder(jnp.concatenate([state, latent], axis=-1))\n",
    "        return jax.nn.tanh(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "\n",
    "class TrainStateEQX(eqx.Module):\n",
    "    '''\n",
    "    Bundle of a model, its optax optimiser and the optimiser state.\n",
    "\n",
    "    Note: this definition replaces the `TrainStateEQX` imported from\n",
    "    jaxrl_m.common in the setup cell for the rest of the notebook.\n",
    "    '''\n",
    "    model: eqx.Module\n",
    "    optim: optax.GradientTransformation\n",
    "    optim_state: optax.OptState\n",
    "\n",
    "    @classmethod\n",
    "    def create(cls, *, model, optim, **kwargs):\n",
    "        '''Build a train state; the optimiser is initialised on the array leaves of `model`.'''\n",
    "        array_leaves = eqx.filter(model, eqx.is_array)\n",
    "        return cls(model=model, optim=optim, optim_state=optim.init(array_leaves), **kwargs)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def apply_updates(self, grads):\n",
    "        '''Return a new train state after one optimiser step with `grads`.'''\n",
    "        step, next_optim_state = self.optim.update(grads, self.optim_state, self.model)\n",
    "        stepped_model = eqx.apply_updates(self.model, step)\n",
    "        return dataclasses.replace(self, model=stepped_model, optim_state=next_optim_state)\n",
    "\n",
    "def vae_loss(model, states, actions, random_key, num_expert=5, beta=0.5, calibration_weight=0.8):\n",
    "    '''\n",
    "    VAE objective: reconstruction + beta * KL + calibration terms on the expert rows.\n",
    "\n",
    "    Args:\n",
    "        model: the VAE; called per sample as model(state, action, key).\n",
    "        states, actions: batches whose LAST `num_expert` rows are expert samples.\n",
    "        random_key: single PRNG key; it is split into one key per sample.\n",
    "        num_expert: number of trailing expert rows in the batch.\n",
    "        beta: weight of the KL term.\n",
    "        calibration_weight: weight of the variance of the expert latent means\n",
    "            and stds across the expert rows.\n",
    "\n",
    "    Returns:\n",
    "        (loss, (mean, std, recon, KL_loss, recon_loss)).\n",
    "    '''\n",
    "    # One key per sample. Broadcasting a single key through vmap (in_axes=None)\n",
    "    # gave every sample of the batch the same noise vector.\n",
    "    sample_keys = jax.random.split(random_key, states.shape[0])\n",
    "    mean, std, recon = eqx.filter_vmap(model, in_axes=(0, 0, 0))(states, actions, sample_keys)\n",
    "\n",
    "    expert_std = std[-num_expert:, :]\n",
    "    expert_mean = mean[-num_expert:, :]\n",
    "    expert_std_loss = jnp.var(expert_std, 0).mean()\n",
    "    expert_mean_loss = jnp.var(expert_mean, 0).mean()\n",
    "\n",
    "    recon_loss = jnp.mean(jnp.square(recon - actions))\n",
    "    # KL(N(mean, std^2) || N(0, I)), averaged over batch and latent dims.\n",
    "    KL_loss = -0.5 * (1 + jnp.log(std**2) - mean ** 2 - std**2).mean()\n",
    "    total_loss = recon_loss + beta * KL_loss + calibration_weight * expert_std_loss + calibration_weight * expert_mean_loss\n",
    "    return total_loss, (mean, std, recon, KL_loss, recon_loss)\n",
    "    \n",
    "@eqx.filter_jit\n",
    "def make_step(vae_learner, states, actions, random_key):\n",
    "    '''\n",
    "    One jitted optimisation step of the VAE.\n",
    "\n",
    "    Returns the updated train state followed by the latent mean, latent std,\n",
    "    reconstruction, total loss, KL term and reconstruction term of vae_loss.\n",
    "    '''\n",
    "    loss_and_grad = eqx.filter_value_and_grad(vae_loss, has_aux=True)\n",
    "    (loss, aux), grads = loss_and_grad(vae_learner.model, states, actions, random_key)\n",
    "    mean, std, recon, KL_loss, recon_loss = aux\n",
    "    stepped_learner = vae_learner.apply_updates(grads)\n",
    "    return stepped_learner, mean, std, recon, loss, KL_loss, recon_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class VAEcfg:\n",
    "    # Hyperparameters for the VAE cells. The defaults of state_dim / action_dim are read\n",
    "    # from the notebook-level `env` when this class body is executed.\n",
    "    state_dim: int = env.observation_space.shape[0]\n",
    "    action_dim: int = env.action_space.shape[0]\n",
    "    # Computed once at class-definition time from the action_dim default above.\n",
    "    latent_dim: int = action_dim * 2\n",
    "    hidden_dim: int = 512\n",
    "\n",
    "    # NOTE(review): lambda_loss, beta and weight_decay are not read by any visible cell;\n",
    "    # vae_loss hard-codes its own weights and the optimiser is plain optax.adam(lr).\n",
    "    lambda_loss: float = 1.0 # calibration mean and std loss weight\n",
    "    beta: float = 0.5\n",
    "    batch_size: int = 256\n",
    "    num_iters: int = 100_000\n",
    "    lr: float = 1e-4\n",
    "    weight_decay: float = 1e-4\n",
    "    #test lr scheduler\n",
    "\n",
    "# Build the VAE and its Adam train state from the config above.\n",
    "vae_cfg = VAEcfg()\n",
    "vae_key = jax.random.PRNGKey(1337)\n",
    "vae = VAE(key=vae_key, state_dim=vae_cfg.state_dim, action_dim=vae_cfg.action_dim, latent_dim=vae_cfg.latent_dim, hidden_dim=vae_cfg.hidden_dim)\n",
    "vae_learner = TrainStateEQX.create(model=vae, optim=optax.adam(vae_cfg.lr))\n",
    "\n",
    "pbar = tqdm(range(vae_cfg.num_iters), desc='Training VAE')\n",
    "for learn_step in pbar:\n",
    "    # Each batch is (batch_size - 5) agent samples followed by 5 expert samples;\n",
    "    # vae_loss relies on the expert rows being the last 5.\n",
    "    agent_sample = gc_agent_dataset.dataset.sample(vae_cfg.batch_size - 5)\n",
    "    expert_sample = gc_expert_dataset.dataset.sample(5)\n",
    "    states = jnp.vstack([agent_sample['observations'], expert_sample['observations']])\n",
    "    actions = jnp.vstack([agent_sample['actions'], expert_sample['actions']])\n",
    "\n",
    "    # Fresh sub-key for every step.\n",
    "    vae_key, random_key = jax.random.split(vae_key, 2)\n",
    "    vae_learner, mean, std, recon, loss, kl, recon_loss = make_step(vae_learner, states, actions, random_key)\n",
    "    # Log to W&B and the progress bar every 500 steps.\n",
    "    if learn_step % 500 == 0:\n",
    "        wandb.log({'VAE-Train/Loss': loss,\n",
    "                  'VAE-Train/ReconLoss': recon_loss,\n",
    "                  'VAE-Train/KL': kl})\n",
    "        pbar.set_postfix({\"Loss\": loss,\n",
    "                          'Recon loss':recon_loss,\n",
    "                     \"KL\": kl})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Encode every expert (state, action) pair with the trained VAE.\n",
    "# The same `vae_key` is shared by all samples (in_axes=None for the key); only the\n",
    "# latent mean and std are kept, the reconstruction is discarded.\n",
    "exp_actions = gc_expert_dataset.dataset.dataset_dict['actions']\n",
    "exp_states = gc_expert_dataset.dataset.dataset_dict['observations']\n",
    "mean_exp, std_exp, _ = eqx.filter_vmap(vae_learner.model, in_axes=(0, 0, None))(exp_states, exp_actions, vae_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Compare the latent statistics of expert samples with a 1000-sample slice of the agent dataset.\n",
    "agent_actions = gc_agent_dataset.dataset.dataset_dict['actions'][24000:25000]\n",
    "agent_states = gc_agent_dataset.dataset.dataset_dict['observations'][24000:25000]\n",
    "\n",
    "# Use a separate name for the per-sample keys: overwriting `vae_key` with a batch of keys\n",
    "# breaks re-running the earlier cells that expect a single key.\n",
    "agent_keys = jax.random.split(jax.random.PRNGKey(43), agent_actions.shape[0])\n",
    "mean_ag, std_ag, _ = eqx.filter_vmap(vae_learner.model, in_axes=(0, 0, 0))(agent_states, agent_actions, agent_keys)\n",
    "\n",
    "# One point per sample: latent mean and latent std, each averaged over the latent dimensions.\n",
    "mean_center = mean_exp.mean(1)\n",
    "std_center = std_exp.mean(1)\n",
    "\n",
    "mean_center_ag = mean_ag.mean(1)\n",
    "std_center_ag = std_ag.mean(1)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(mean_center, std_center, c='red', label='Mean&Std Expert Traj')\n",
    "ax.scatter(mean_center_ag, std_center_ag, c='blue', label='Mean&Std Agent Traj')\n",
    "ax.set_xlabel('latent mean (averaged over latent dims)')\n",
    "ax.set_ylabel('latent std (averaged over latent dims)')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_rewards_per_step(gc_agent_ds, mean_center):\n",
    "    '''\n",
    "    Reward for every agent transition: exp(-||latent_mean - mean_center||).\n",
    "\n",
    "    The latent mean comes from the notebook-level `vae_learner.model`, applied\n",
    "    to all observations/actions of `gc_agent_ds` with one shared PRNG key\n",
    "    (the key only affects the discarded reconstruction).\n",
    "    '''\n",
    "    encode_all = eqx.filter_vmap(vae_learner.model, in_axes=(0, 0, None))\n",
    "    latent_mean, _, _ = encode_all(gc_agent_ds.dataset_dict['observations'],\n",
    "                                   gc_agent_ds.dataset_dict['actions'],\n",
    "                                   jax.random.PRNGKey(228))\n",
    "    distance = jnp.linalg.norm(latent_mean - mean_center[None], axis=-1)\n",
    "    return jnp.exp(-distance)\n",
    "\n",
    "# The centre is the expert latent mean averaged over all expert samples.\n",
    "rewards = compute_rewards_per_step(gc_agent_dataset.dataset, mean_exp.mean(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the rewards to host memory as a NumPy array so they can be indexed with NumPy indices later.\n",
    "rewards = np.asarray(jax.device_get(rewards))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Apply iql scaling\n",
    "from utils.ds_builder import load_trajectories\n",
    "    \n",
    "# NOTE(review): the trajectories are loaded for \"antmaze-large-diverse-v2\" while `rewards` was\n",
    "# computed on the halfcheetah-medium-expert-v2 agent dataset built earlier - confirm the env name.\n",
    "offline_traj = load_trajectories(\"antmaze-large-diverse-v2\", rewards)\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    '''\n",
    "    Reward scale from the spread of trajectory returns: 1000 / (max return - min return).\n",
    "    This is also used in the original IQL implementation.\n",
    "\n",
    "    Args:\n",
    "        trajs: sequence of trajectories; each trajectory is a sequence of\n",
    "            steps whose element at index 2 is the reward. Not modified.\n",
    "\n",
    "    Returns:\n",
    "        The multiplicative reward scale.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `trajs` is empty or all trajectories have the same return.\n",
    "    '''\n",
    "    # Each return is computed exactly once; min/max replace sorting the whole list.\n",
    "    returns = [sum(step[2] for step in tr) for tr in trajs]\n",
    "    if not returns:\n",
    "        raise ValueError('compute_iql_reward_scale needs at least one trajectory')\n",
    "\n",
    "    return_range = max(returns) - min(returns)\n",
    "    if return_range == 0:\n",
    "        raise ValueError('all trajectory returns are equal; the reward scale is undefined')\n",
    "    return 1000.0 / return_range\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reward multiplier applied to the VAE rewards in the IQL training loop below.\n",
    "iql_scale = compute_iql_reward_scale(offline_traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug aid: show the working directory (an earlier cell changes it with `%cd ..`).\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "        \n",
    "# IQL learner trained on the VAE-derived rewards instead of the dataset rewards.\n",
    "# NOTE(review): seed 1337 matches the W&B config of this section, but expectile / discount /\n",
    "# temperature here (0.9 / 0.999 / 6) differ from the logged config (0.95 / 0.99 / 3).\n",
    "iql_agent_vae = Learner(\n",
    "        1337,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=1_000_000,\n",
    "        expectile=0.9,\n",
    "        discount=0.999,\n",
    "        temperature=6)\n",
    "\n",
    "pbar = tqdm(range(1_000_000))\n",
    "for i in pbar:\n",
    "    # Draw explicit indices so the same rows can be looked up in the `rewards` array.\n",
    "    rand_indx = np.random.randint(low=0, high=rewards.shape[0], size=512)\n",
    "    sample = gc_agent_dataset.dataset.sample(batch_size=512, indx=rand_indx)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample[\"actions\"],\n",
    "        rewards = rewards[np.asarray(rand_indx)] * iql_scale,# - 2,\n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent_vae.update(batch)\n",
    "    # The 'adv' entry is blanked before logging; presumably it is a per-sample array - TODO confirm.\n",
    "    update_info['adv'] = None\n",
    "    # Evaluate on 10 episodes every 50k steps (skipping step 0).\n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent_vae, env, num_episodes=10)\n",
    "        wandb.log({'Training/eval' :eval_stats})\n",
    "        print(eval_stats)\n",
    "        pbar.set_postfix(update_info)\n",
    "    # Log training metrics every 3000 steps.\n",
    "    if i % 3000 == 0:\n",
    "        wandb.log({'Training/Metrics' :update_info})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final evaluation of the VAE-reward agent over 100 episodes, logged under 'Final Eval/<stat>'.\n",
    "eval_stats = evaluate(iql_agent_vae, env, num_episodes=100)\n",
    "wandb.log({f\"Final Eval/{k}\": stat for k, stat in eval_stats.items()})\n",
    "eval_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the W&B run opened for the IQL + VAE-reward experiment.\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second 100-episode evaluation of the same agent, run after the W&B run is closed (not logged).\n",
    "eval_stats = evaluate(iql_agent_vae, env, num_episodes=100)\n",
    "eval_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "scales_shifts = [(6, 40), (-6, 230)]\n",
    "def scale_and_shift(x, lst):\n",
    "    return lst[0] * x + lst[1]\n",
    "    \n",
    "def plot_traj_image(traj, start, goal, subgoals, bgpath):\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    pimage = plt.imread(bgpath)\n",
    "    ax.imshow(pimage)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.scatter(scale_and_shift(expert_trajectory[:, 0], scales_shifts[0]), scale_and_shift(expert_trajectory[:, 1], scales_shifts[1]), alpha=1, label='trajectory', color='orange')\n",
    "    ax.scatter(scale_and_shift(start[0], scales_shifts[0]), scale_and_shift(start[1], scales_shifts[1]), c='g', s=100, label='start')\n",
    "    ax.scatter(scale_and_shift(goal[0], scales_shifts[0]), scale_and_shift(goal[1], scales_shifts[1]), c='r', s=100, label='goal')\n",
    "    ax.scatter(scale_and_shift(subgoals[0], scales_shifts[0]), scale_and_shift(subgoals[1], scales_shifts[1]), c='b', s=100, label='S_{t+k}')\n",
    "    ax.legend(fontsize=10)\n",
    "\n",
    "sample = gc_agent_dataset.dataset.sample(1)\n",
    "start_point = expert_trajectory[0]\n",
    "target_goal = expert_trajectory[-1]\n",
    "sample_key = jax.random.PRNGKey(42)\n",
    "\n",
    "env.reset()\n",
    "env.env.env.wrapped_env.set_xy((start_point[0], start_point[1]))\n",
    "env.env.env.wrapped_env.set_target((target_goal[0], target_goal[1]))\n",
    "start_point = env.env.env.wrapped_env._get_obs()\n",
    "curr_point = start_point\n",
    "frames=[]\n",
    "\n",
    "i = 0\n",
    "done = False\n",
    "while not done:\n",
    "    key, sample_key = jax.random.split(sample_key, 2)\n",
    "    action = jax.device_get(iql_agent_vae.sample_actions(curr_point.squeeze(), temperature=0.0))\n",
    "    new_obs, reward, done ,_ = env.step(action)\n",
    "    \n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "    frames.append(env.render(mode='rgb_array'))\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='1,2,3'\n",
    "    if done:\n",
    "        print(reward)  \n",
    "    if i % 100 == 0:\n",
    "        plot_traj_image(sample, new_obs, target_goal, new_obs, \"/home/m_bobrin/AILOT/notebooks/antmaze-large.png\")\n",
    "\n",
    "    plt.show()\n",
    "    curr_point = new_obs\n",
    "    i+=1\n",
    "# save_video.save_video(frames, video_folder='.', fps=env.env.env._wrapped_env.metadata['render_fps'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Our Algo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from jax.numpy import ndarray\n",
    "from ott.geometry import costs\n",
    "from ott.math import utils as mu\n",
    "\n",
    "class OTRewardsExpert:\n",
    "\n",
    "    def __init__(\n",
    "        self, expert_traj,\n",
    "    ):\n",
    "        self.expert_states = expert_traj\n",
    "        self.expert_z = batched_eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=1) # for several expert trajs\n",
    "        self.sub_steps = 1\n",
    "\n",
    "    def make_subs(self, z, sub_steps):\n",
    "        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)\n",
    "        return jax.tree_map(lambda arr: arr[sub_indx], z)\n",
    "    \n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def get_z_and_start_index(self, obs):\n",
    "        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "        diff = jnp.linalg.norm(z[0][None] - self.expert_z, axis=-1)\n",
    "        i_min = diff.argmin(axis=-1)#jnp.argmin((diff**2).sum(-1)).squeeze()\n",
    "        return z, i_min, diff\n",
    "\n",
    "    def compute_rewards(\n",
    "        self,\n",
    "        dataset\n",
    "    ):\n",
    "        i0 = 0\n",
    "        rewards = []\n",
    "        observations = dataset.dataset_dict['observations']\n",
    "        episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "        \n",
    "        for i1 in tqdm(range(len(episode_starts))):\n",
    "            zi, start_index, diff = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])\n",
    "            ri = self.compute_rewards_one_episode(zi, self.expert_z[:, start_index:])\n",
    "            rewards.append(jax.device_get(ri))\n",
    "                  \n",
    "        return np.concatenate(rewards)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def compute_rewards_one_episode(\n",
    "        self, episode_obs, expert_obs\n",
    "    ):\n",
    "\n",
    "        za_1 = episode_obs\n",
    "        za_2 = self.make_subs(za_1, self.sub_steps)\n",
    "        x = jnp.concatenate([za_1, za_2], axis=1)\n",
    "        \n",
    "        ze_1 = expert_obs\n",
    "        ze_2 = self.make_subs(ze_1, self.sub_steps)\n",
    "        y = jnp.concatenate([ze_1, ze_2], axis=1)\n",
    "        \n",
    "        geom = pointcloud.PointCloud(x, y, epsilon=0.001)\n",
    "        ot_prob = linear_problem.LinearProblem(geom)\n",
    "        solver = sinkhorn.Sinkhorn(max_iterations=300, use_danskin=True)\n",
    "\n",
    "        ot_sink = solver(ot_prob)\n",
    "        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)\n",
    "        rewards = -transp_cost * episode_obs.shape[0] / 10\n",
    "        print(rewards.shape)\n",
    "        return rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0))\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def batched_eval_ensemble_psi(ensemble, s):\n",
    "    return eqx.filter_vmap(ensemble.psi_net)(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4995, 17)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "expert_trajectory.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "expert_trajectory = expert_trajectory.reshape(5, -1, env.observation_space.shape[0]) # first arg - number of expert trajs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-27T15:08:01.109Z"
    }
   },
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from jax.numpy import ndarray\n",
    "from ott.geometry import costs\n",
    "from ott.math import utils as mu\n",
    "\n",
    "class OTRewardsExpert:\n",
    "\n",
    "    def __init__(\n",
    "        self, expert_traj,\n",
    "    ):\n",
    "        self.expert_states = expert_traj\n",
    "        self.expert_z = batched_eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=1) # for several expert trajs\n",
    "        self.sub_steps = 1\n",
    "\n",
    "    def make_subs(self, z, sub_steps):\n",
    "        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)\n",
    "        return jax.tree_map(lambda arr: arr[sub_indx], z)\n",
    "    \n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def get_z_and_start_index(self, obs):\n",
    "        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "        diff = jnp.linalg.norm(z[0][None] - self.expert_z, axis=-1)\n",
    "        i_min = diff.argmin(axis=-1)\n",
    "        return z, i_min, diff\n",
    "\n",
    "    def compute_rewards(\n",
    "        self,\n",
    "        dataset\n",
    "    ):\n",
    "        i0 = 0\n",
    "        rewards = []\n",
    "        observations = dataset.dataset_dict['observations']\n",
    "        episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "        \n",
    "        for i1 in tqdm(range(len(episode_starts))):\n",
    "            zi, start_index, diff = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])\n",
    "            expert_demos = []\n",
    "            for i in range(self.expert_z.shape[0]):\n",
    "                traj = self.expert_z[i][start_index[i]:]\n",
    "                expert_demos.append(jnp.pad(traj, ((0, 1000 - traj.shape[0]), (0, 0)), mode='constant', constant_values=0.))\n",
    "            expert_demos = jnp.stack(expert_demos)\n",
    "            ri = self.compute_rewards_one_episode(zi, expert_demos)\n",
    "            rewards.append(jax.device_get(ri))\n",
    "                  \n",
    "        return np.concatenate(rewards)\n",
    "\n",
    "    @eqx.filter_vmap(in_axes=(None, None, 0))\n",
    "    def batched_ot(self, x, y):\n",
    "        geom = pointcloud.PointCloud(x, y, epsilon=0.001)\n",
    "        ot_prob = linear_problem.LinearProblem(geom)\n",
    "        solver = sinkhorn.Sinkhorn(max_iterations=200, use_danskin=True)\n",
    "\n",
    "        ot_sink = solver(ot_prob)\n",
    "        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)\n",
    "        rewards = 3 * jnp.exp(-1000 * 0.1 * 4 * transp_cost)\n",
    "        return rewards\n",
    "        \n",
    "    @eqx.filter_jit\n",
    "    def compute_rewards_one_episode(\n",
    "        self, episode_obs, expert_obs\n",
    "    ):\n",
    "        za_1 = episode_obs\n",
    "        za_2 = self.make_subs(za_1, self.sub_steps)\n",
    "        x = jnp.concatenate([za_1, za_2], axis=-1)\n",
    "        \n",
    "        ze_1 = expert_obs\n",
    "        ze_2 = self.make_subs(ze_1, self.sub_steps)\n",
    "        y = jnp.concatenate([ze_1, ze_2], axis=-1)\n",
    "        rewards = jnp.transpose(self.batched_ot(x, y), axes=(1, 0))\n",
    "        rewards = jnp.max(rewards, axis=-1)\n",
    "        return rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-27T15:08:01.867Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b089e16fe23d43ac9ec56138d96638ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "expert = OTRewardsExpert(expert_trajectory)\n",
    "rewards = expert.compute_rewards(gc_agent_dataset.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile: 100%|██████████████████████████████████████████████████████████████████████████████████| 9/9 [00:04<00:00,  1.81it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "acb2c1c826e1423daf926af757682b58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1998000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Apply iql scaling\n",
    "from utils.ds_builder import load_trajectories\n",
    "    \n",
    "offline_traj = load_trajectories(env.spec.id, rewards) # scaled_rewards\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    \"\"\"Rescale rewards based on max/min from the dataset.\n",
    "    This is also used in the original IQL implementation.\n",
    "    \"\"\"\n",
    "    trajs = trajs.copy()\n",
    "    \n",
    "    def compute_returns(tr):\n",
    "        return sum([step[2] for step in tr])\n",
    "    \n",
    "    trajs.sort(key=compute_returns)\n",
    "    reward_scale = 1000.0 / (\n",
    "      compute_returns(trajs[-1]) - compute_returns(trajs[0]))\n",
    "    return reward_scale\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset import Dataset\n",
    "\n",
    "ds = gc_agent_dataset.dataset.dataset_dict\n",
    "episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "data_with_ot_rewards = Dataset(\n",
    "    {'observations': np.concatenate([ds['observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),\n",
    "    'next_observations': np.concatenate([ds['next_observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),\n",
    "    'actions': np.concatenate([ds['actions'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),\n",
    "    'rewards': rewards * compute_iql_reward_scale(offline_traj),\n",
    "    'masks': 1.0 - np.concatenate([ds['dones'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Finishing last run (ID:0m0pawy8) before initializing another..."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45f3f7813ea6428fa9e8003ba7453621",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>Eval/length</td><td>▁</td></tr><tr><td>Eval/return</td><td>▁</td></tr><tr><td>Training/actor_loss</td><td>█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▂▁▁▁▂▂▁▁▁▂▁▁▁▁▁▃▂▂▁</td></tr><tr><td>Training/critic_loss</td><td>█▁▁▁▁▁▁▁▁▁▂▂▁▁▂▁▁▂▂▁▁▁▁▁▁▂▁▁▂▁▂▂▂▂▂▁▂▂▂▂</td></tr><tr><td>Training/q1</td><td>▁▂▂▃▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇█▇▇▇██▇▇█▇█▇████</td></tr><tr><td>Training/q2</td><td>▁▁▂▃▃▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇█▇▇▇██▇▇█▇█▇████</td></tr><tr><td>Training/v</td><td>▁▂▂▃▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇█▇▇▇██▇▇█▇█▇████</td></tr><tr><td>Training/value_loss</td><td>█▁▁▁▁▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▃▂▂▂▂▂▃▂▂▂▃▂▂▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>Eval/length</td><td>1000.0</td></tr><tr><td>Eval/return</td><td>5.29684</td></tr><tr><td>Training/actor_loss</td><td>-19.75833</td></tr><tr><td>Training/critic_loss</td><td>2.91662</td></tr><tr><td>Training/q1</td><td>28.30139</td></tr><tr><td>Training/q2</td><td>28.28683</td></tr><tr><td>Training/v</td><td>28.22572</td></tr><tr><td>Training/value_loss</td><td>0.1918</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">halfcheetah-medium-expert-v2</strong> at: <a href='https://wandb.ai/simmax21/ForPaper/runs/0m0pawy8' target=\"_blank\">https://wandb.ai/simmax21/ForPaper/runs/0m0pawy8</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240115_171349-0m0pawy8/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Successfully finished last run (ID:0m0pawy8). Initializing new run:<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e968cea7953c4717a5fd0bf60bff503b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112768699725469, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.2 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.12"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/m_bobrin/AILOT/wandb/run-20240115_172547-q2rga4q5</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/simmax21/ForPaper/runs/q2rga4q5' target=\"_blank\">halfcheetah-medium-expert-v2</a></strong> to <a href='https://wandb.ai/simmax21/ForPaper' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/simmax21/ForPaper' target=\"_blank\">https://wandb.ai/simmax21/ForPaper</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/simmax21/ForPaper/runs/q2rga4q5' target=\"_blank\">https://wandb.ai/simmax21/ForPaper/runs/q2rga4q5</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/simmax21/ForPaper/runs/q2rga4q5?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
      ],
      "text/plain": [
       "<wandb.sdk.wandb_run.Run at 0x7f824743d790>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import wandb\n",
    "\n",
    "wandb.init(project='ForPaper', group='AILOT-MultiExpert', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 10, 'expectile':0.7, 'discount': 0.99, 'temperature': 3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-26T23:19:09.003Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "895c1410cfb64a0cbeb87b6c2da46da6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'return': 1788.4620106100435, 'length': 1000.0}\n",
      "{'return': 1564.9741344102522, 'length': 1000.0}\n",
      "{'return': 3074.395431439654, 'length': 1000.0}\n",
      "{'return': 5116.340728355726, 'length': 1000.0}\n",
      "{'return': 4437.233463266752, 'length': 1000.0}\n",
      "{'return': 7609.667771745031, 'length': 1000.0}\n",
      "{'return': 8674.144393006174, 'length': 1000.0}\n",
      "{'return': 7154.076017003461, 'length': 1000.0}\n",
      "{'return': 8455.904761269496, 'length': 1000.0}\n",
      "{'return': 8208.20504998016, 'length': 1000.0}\n",
      "{'return': 8948.656609537169, 'length': 1000.0}\n",
      "{'return': 7861.44565258747, 'length': 1000.0}\n",
      "{'return': 10542.789847042826, 'length': 1000.0}\n",
      "{'return': 10516.319124328864, 'length': 1000.0}\n",
      "{'return': 9263.483884421821, 'length': 1000.0}\n",
      "{'return': 9383.809706123431, 'length': 1000.0}\n",
      "{'return': 10609.119187361732, 'length': 1000.0}\n",
      "{'return': 10453.537348100926, 'length': 1000.0}\n",
      "{'return': 10731.202918420337, 'length': 1000.0}\n"
     ]
    }
   ],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "iql_agent_ot = Learner(\n",
    "        wandb.config.seed,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=wandb.config.max_steps,\n",
    "        expectile=wandb.config.expectile,\n",
    "        discount=wandb.config.discount,\n",
    "        temperature=wandb.config.temperature)\n",
    "\n",
    "pbar = tqdm(range(wandb.config.max_steps))\n",
    "expert = OTRewardsExpert(expert_trajectory)\n",
    "\n",
    "for i in pbar:\n",
    "    sample = data_with_ot_rewards.sample(256)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample['actions'],\n",
    "        rewards= sample[\"rewards\"],\n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent_ot.update(batch)\n",
    "    update_info['adv'] = None \n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)\n",
    "        print(eval_stats)\n",
    "        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "        wandb.log({f\"Eval/{key}\": value for key, value in eval_stats.items()})\n",
    "        pbar.set_postfix(update_info)\n",
    "    if i % 2000 == 0:\n",
    "        wandb.log({f\"Training/{key}\": value for key, value in update_info.items()})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84.36980536734823\n",
      "83.19496151175093\n",
      "86.49827431871417\n",
      "84.89019129751595\n",
      "78.6462337610193\n",
      "86.68024646434857\n",
      "79.93628108446032\n",
      "88.0105374728517\n",
      "86.84938015466824\n",
      "86.21227660891945\n",
      "87.4761072626642\n",
      "84.69437106296402\n",
      "84.66783427395559\n",
      "83.63897206648461\n",
      "87.66158863655635\n",
      "82.67445060191794\n",
      "86.27726331492916\n",
      "81.63361660289166\n",
      "81.74831678866148\n",
      "86.14787764592741\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "84.59542931492746"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "means = []\n",
    "for i in range(20):\n",
    "    eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)\n",
    "    means.append(env.get_normalized_score(eval_stats['return'])*100)\n",
    "    print(env.get_normalized_score(eval_stats['return'])*100)\n",
    "    means.append(eval_stats['return'])\n",
    "    wandb.log({\"FinalEval/return\": env.get_normalized_score(eval_stats['return'])*100})\n",
    "np.asarray(means[0::2]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.6373492684934647"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.asarray(means[0::2]).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>Eval/length</td><td>▁▁▁▁▁▁▁▁▁</td></tr><tr><td>Eval/return</td><td>▁▅▇▇█████</td></tr><tr><td>FinalEval/return</td><td>▄▅▅▂▁▆▅▂▄▅▅▂▆█▄▆▄▃▅▇</td></tr><tr><td>Training/actor_loss</td><td>▇▆▆▆▄▅▅▅▄▄▄▇▅▅▆▄▃▄▄█▄▃▄▃▃▅▅▄▃▄▃▄▁▃▂▃▅▃▁▂</td></tr><tr><td>Training/critic_loss</td><td>▅▄▄▂▂▂▂▂▂▅▂▂▂▃▂▁▃▅▅▃█▁▃▁▁▇▄▄▁▂▂▄▂▄▃▁▁▁▁▂</td></tr><tr><td>Training/q1</td><td>▁▄▇▇████████▇▇███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇</td></tr><tr><td>Training/q2</td><td>▁▄▇▇████████▇████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇</td></tr><tr><td>Training/v</td><td>▁▄▇███████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇</td></tr><tr><td>Training/value_loss</td><td>▁▂▄▅▃▃▃▃▂▄▆▃▃▃▆▃▃▅▄▄▄▆▅▃▃▃▆▃▃▅█▄▄▄▄█▇▄▄▄</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>Eval/length</td><td>1000.0</td></tr><tr><td>Eval/return</td><td>47.19503</td></tr><tr><td>FinalEval/return</td><td>47.46241</td></tr><tr><td>Training/actor_loss</td><td>-4.29264</td></tr><tr><td>Training/critic_loss</td><td>4.31652</td></tr><tr><td>Training/q1</td><td>36.0746</td></tr><tr><td>Training/q2</td><td>36.04326</td></tr><tr><td>Training/v</td><td>35.88317</td></tr><tr><td>Training/value_loss</td><td>0.23533</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">halfcheetah-medium-v2</strong> at: <a href='https://wandb.ai/simmax21/ForPaper/runs/kndbtho4' target=\"_blank\">https://wandb.ai/simmax21/ForPaper/runs/kndbtho4</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240115_140619-kndbtho4/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-12-27T09:26:53.232731Z",
     "start_time": "2023-12-27T09:24:39.840882Z"
    }
   },
   "outputs": [],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "scales_shifts = [(6, 40), (-6, 230)]\n",
    "def scale_and_shift(x, lst):\n",
    "    return lst[0] * x + lst[1]\n",
    "    \n",
    "def plot_traj_image(traj, start, goal, subgoals, bgpath):\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    pimage = plt.imread(bgpath)\n",
    "    ax.imshow(pimage)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.scatter(scale_and_shift(expert_trajectory[:, 0], scales_shifts[0]), scale_and_shift(expert_trajectory[:, 1], scales_shifts[1]), alpha=1, label='trajectory', color='orange')\n",
    "    ax.scatter(scale_and_shift(start[0], scales_shifts[0]), scale_and_shift(start[1], scales_shifts[1]), c='g', s=100, label='start')\n",
    "    ax.scatter(scale_and_shift(goal[0], scales_shifts[0]), scale_and_shift(goal[1], scales_shifts[1]), c='r', s=100, label='goal')\n",
    "    ax.scatter(scale_and_shift(subgoals[0], scales_shifts[0]), scale_and_shift(subgoals[1], scales_shifts[1]), c='b', s=100, label='S_{t+k}')\n",
    "    ax.legend(fontsize=10)\n",
    "\n",
    "start_point = expert_trajectory[0]\n",
    "target_goal = expert_trajectory[-1]\n",
    "sample_key = jax.random.PRNGKey(42)\n",
    "\n",
    "env.reset()\n",
    "env.env.env.wrapped_env.set_xy((start_point[0], start_point[1]))\n",
    "env.env.env.wrapped_env.set_target((target_goal[0], target_goal[1]))\n",
    "start_point = env.env.env.wrapped_env._get_obs()\n",
    "curr_point = start_point\n",
    "frames=[]\n",
    "\n",
    "i = 0\n",
    "done = False\n",
    "while not done:\n",
    "    key, sample_key = jax.random.split(sample_key, 2)\n",
    "    action = jax.device_get(iql_agent_ot.sample_actions(curr_point.squeeze(), temperature=0.0))\n",
    "    new_obs, reward, done ,_ = env.step(action)\n",
    "    \n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "    frames.append(env.render(mode='rgb_array'))\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='1,2,3'\n",
    "    if done:\n",
    "        print(reward)  \n",
    "    if i % 100 == 0:\n",
    "        plot_traj_image(sample, new_obs, target_goal, new_obs, \"/home/m_bobrin/AILOT/notebooks/antmaze-large.png\")\n",
    "\n",
    "    plt.show()\n",
    "    curr_point = new_obs\n",
    "    i+=1\n",
    "# save_video.save_video(frames, video_folder='.', fps=env.env.env._wrapped_env.metadata['render_fps'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_video.save_video(frames, video_folder='.', fps=25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Latent IQL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='D4RL-jupyter', group='IQL-Our-Latent', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 1337, 'expectile':0.9, 'discount': 0.999, 'temperature': 6})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "z_sample = eval_ensemble_psi(icvf_model.value_learner.model, env.observation_space.sample()[np.newaxis]).mean(axis=0)\n",
    "iql_agent_ot_latent = Learner(\n",
    "        1337,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        z_sample,\n",
    "        max_steps=1_000_000,\n",
    "        expectile=0.9,\n",
    "        discount=0.999,\n",
    "        temperature=6,\n",
    "        latent=True)\n",
    "\n",
    "@eqx.filter_jit\n",
    "def get_z(obs):\n",
    "    return eval_ensemble_phi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "\n",
    "pbar = tqdm(range(1_000_000))\n",
    "expert = OTRewardsExpert(expert_trajectory)\n",
    "\n",
    "for i in pbar:\n",
    "    sample = data_with_ot_rewards.sample(256)\n",
    "\n",
    "    if i % 3 == 0:\n",
    "        subs_i = sample[\"sub_observations_5\"]\n",
    "    elif i % 3 == 1:\n",
    "        subs_i = sample[\"sub_observations_5\"]\n",
    "    else:\n",
    "        subs_i = sample[\"sub_observations_10\"]\n",
    "\n",
    "    z = get_z(subs_i)\n",
    "    \n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = z / 10,\n",
    "        rewards= sample[\"rewards\"],\n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent_ot_latent.update(batch)\n",
    "    update_info['adv'] = None\n",
    "    if i % 3000 == 0:\n",
    "        expert_z = get_z(expert_trajectory)\n",
    "        expert_sub_z = expert.make_subs(expert_z, 5)\n",
    "        pred_subs = iql_agent_ot_latent.sample_actions(expert_trajectory, temperature=0) * 10\n",
    "        update_info[\"eval\"] = (jnp.abs(pred_subs - expert_sub_z).mean())\n",
    "        update_info['Vszz Adv'] = (eval_ensemble_icvf_latent_zzz(icvf_model.value_learner.model, pred_subs[0][None], expert_sub_z[0][None], expert_sub_z[0][None]) - eval_ensemble_icvf_latent_zz(icvf_model.value_learner.model,sample[\"observations\"][0][None], expert_sub_z[0][None], expert_sub_z[0][None])).mean(0)\n",
    "        wandb.log({'Training/Metrics': update_info})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from jaxrl_m.common import TrainStateEQX\n",
    "from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy\n",
    "\n",
    "def update_low_actor(iql_agent_ot_latent, batch_size, iter_num):\n",
    "    \n",
    "    key = jax.random.PRNGKey(42)\n",
    "    actor_learner = TrainStateEQX.create(\n",
    "            model=GaussianPolicy(key=key,\n",
    "                                 hidden_dims=[256, 256, 256],\n",
    "                                 state_dim=env.observation_space.shape[0],\n",
    "                                 intents_dim=256,\n",
    "                                 action_dim=env.action_space.shape[0]),\n",
    "            optim=optax.adam(learning_rate=3e-4)\n",
    "        )\n",
    "    def optimize_actor(actor_learner, sample, cur_goals_z, key):\n",
    "        v = eval_ensemble_icvf_latent_zz(icvf_model.value_learner.model, sample['observations'], cur_goals_z, cur_goals_z).mean(0)\n",
    "        nv = eval_ensemble_icvf_latent_zz(icvf_model.value_learner.model, sample['next_observations'], cur_goals_z, cur_goals_z).mean(0)\n",
    "        adv = nv - v\n",
    "        exp_a = jnp.minimum(jnp.exp(adv * 6.0), 100.0)\n",
    "        actor_dist = eqx.filter_vmap(actor_learner)(sample['observations'], cur_goals_z)\n",
    "        log_prob = actor_dist.log_prob(sample['actions']).sum(-1)\n",
    "        loss = -(exp_a * log_prob).mean()\n",
    "        return loss, adv.mean()\n",
    "\n",
    "    \n",
    "    actor_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad)(optimize_actor, has_aux=True)\n",
    "    \n",
    "    @eqx.filter_jit\n",
    "    def make_step(actor_learner, sample, cur_goals_z, key):\n",
    "        (loss, adv), grads = actor_loss_fn(actor_learner.model, sample, cur_goals_z, key)\n",
    "        actor_learner = actor_learner.apply_updates(grads)\n",
    "        return actor_learner, loss, adv\n",
    "\n",
    "        \n",
    "    pbar = tqdm(range(iter_num))\n",
    "    for i in pbar:\n",
    "        sample = data_with_ot_rewards.sample(batch_size)\n",
    "        key, sample_key = jax.random.split(key, 2)\n",
    "        cur_goals_z = iql_agent_ot_latent.sample_actions(sample['observations'], temperature=0) * 10\n",
    "        actor_learner, loss, adv = make_step(actor_learner, sample, cur_goals_z, key=sample_key)\n",
    "        if i % 1000 == 0:\n",
    "            pbar.set_postfix({\"Loss\": loss, \"Adv\": adv})\n",
    "            \n",
    "    return actor_learner\n",
    "    \n",
    "updated_actor = update_low_actor(iql_agent_ot_latent, batch_size=256, iter_num=300_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "scales_shifts = [(6, 40), (-6, 230)]\n",
    "def scale_and_shift(x, lst):\n",
    "    return lst[0] * x + lst[1]\n",
    "    \n",
    "def plot_traj_image(traj, start, goal, subgoals, bgpath):\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    pimage = plt.imread(bgpath)\n",
    "    ax.imshow(pimage)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.scatter(scale_and_shift(expert_trajectory[:, 0], scales_shifts[0]), scale_and_shift(expert_trajectory[:, 1], scales_shifts[1]), alpha=1, label='trajectory', color='orange')\n",
    "    ax.scatter(scale_and_shift(start[0], scales_shifts[0]), scale_and_shift(start[1], scales_shifts[1]), c='g', s=100, label='start')\n",
    "    ax.scatter(scale_and_shift(goal[0], scales_shifts[0]), scale_and_shift(goal[1], scales_shifts[1]), c='r', s=100, label='goal')\n",
    "    ax.scatter(scale_and_shift(subgoals[0], scales_shifts[0]), scale_and_shift(subgoals[1], scales_shifts[1]), c='b', s=100, label='S_{t+k}')\n",
    "    ax.legend(fontsize=10)\n",
    "\n",
    "start_point = expert_trajectory[0]\n",
    "target_goal = expert_trajectory[-1]\n",
    "sample_key = jax.random.PRNGKey(42)\n",
    "\n",
    "env.reset()\n",
    "env.env.env.wrapped_env.set_xy((start_point[0], start_point[1]))\n",
    "env.env.env.wrapped_env.set_target((target_goal[0], target_goal[1]))\n",
    "start_point = env.env.env.wrapped_env._get_obs()\n",
    "curr_point = start_point\n",
    "frames=[]\n",
    "\n",
    "i = 0\n",
    "done = False\n",
    "while not done:\n",
    "    key, sample_key = jax.random.split(sample_key, 2)\n",
    "    intent = jax.device_get(iql_agent_ot_latent.sample_actions(curr_point.squeeze(), temperature=0)) * 10\n",
    "    action = jax.device_get(updated_actor.model(curr_point.squeeze(), intent, temperature=0).mean())\n",
    "    new_obs, reward, done ,_ = env.step(action)\n",
    "    \n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "    frames.append(env.render(mode='rgb_array'))\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='1,2,3'\n",
    "    if done:\n",
    "        print(reward)  \n",
    "    if i % 100 == 0:\n",
    "        plot_traj_image(sample, new_obs, target_goal, new_obs, \"/home/m_bobrin/AILOT/notebooks/antmaze-large.png\")\n",
    "\n",
    "    plt.show()\n",
    "    curr_point = new_obs\n",
    "    i+=1\n",
    "# save_video.save_video(frames, video_folder='.', fps=env.env.env._wrapped_env.metadata['render_fps'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MUJOCO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gc_dataset import GCSDataset\n",
    "from utils.ds_builder import setup_datasets\n",
    "\n",
    "env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(expert_env_name=\"walker2d-medium-replay-v2\",\n",
    "                                          agent_env_name=\"walker2d-medium-replay-v2\", expert_num=1,\n",
    "                                          normalize_agent_states=False)\n",
    "\n",
    "gcsds_params = GCSDataset.get_default_config()\n",
    "gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)\n",
    "gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)\n",
    "\n",
    "expert_trajectory = gc_expert_dataset.dataset.dataset_dict['observations'] #c_expert_dataset.get_expert_traj()['observations']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%cd ..\n",
    "from src.agents import icvf\n",
    "icvf_model = icvf.create_eqx_learner(seed=9804,\n",
    "                                     observations=expert_ds.dataset_dict['observations'][0],\n",
    "                                     hidden_dims=[256, 256],\n",
    "                                     pretrained_folder=\"walker2d-medium-replay\",\n",
    "                                     load_pretrained_icvf=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## IQL baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='D4RL-jupyter', group='IQL-base', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 9804, 'expectile':0.95, 'discount': 0.99, 'temperature': 3}) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Apply iql scaling\n",
    "from utils.ds_builder import load_trajectories\n",
    "    \n",
    "offline_traj = load_trajectories(\"halfcheetah-medium-v2\", gc_agent_dataset.dataset.dataset_dict['rewards'])\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    \"\"\"Rescale rewards based on max/min from the dataset.\n",
    "    This is also used in the original IQL implementation.\n",
    "    \"\"\"\n",
    "    trajs = trajs.copy()\n",
    "    \n",
    "    def compute_returns(tr):\n",
    "        return sum([step[2] for step in tr])\n",
    "    \n",
    "    trajs.sort(key=compute_returns)\n",
    "    reward_scale = 1000.0 / (\n",
    "      compute_returns(trajs[-1]) - compute_returns(trajs[0]))\n",
    "    return reward_scale\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iql_scale = compute_iql_reward_scale(offline_traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "max_steps = 1_000_000\n",
    "\n",
    "iql_agent = Learner(\n",
    "        9804,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=max_steps,\n",
    "        expectile=0.95,\n",
    "        discount=0.99,\n",
    "        temperature=3)\n",
    "\n",
    "pbar = tqdm(range(max_steps))\n",
    "for i in pbar:\n",
    "    sample = gc_agent_dataset.dataset.sample(batch_size=256)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample[\"actions\"],\n",
    "        rewards = sample[\"rewards\"] * iql_scale, \n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent.update(batch)\n",
    "    update_info['adv'] = None\n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent, env, num_episodes=10)\n",
    "        print(eval_stats)\n",
    "        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "        wandb.log({'Eval': eval_stats})\n",
    "        pbar.set_postfix(update_info)\n",
    "    if i % 3000 == 0:\n",
    "        wandb.log({'Training/': update_info})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "frames=[]\n",
    "i = 0\n",
    "num_episodes = 1 # 1 for render\n",
    "all_reward = []\n",
    "key = jax.random.PRNGKey(42)\n",
    "\n",
    "for i in range(num_episodes):\n",
    "    episode_reward = 0\n",
    "    key, sample_key = jax.random.split(key, 2)\n",
    "    obs = env.reset()\n",
    "    done = False\n",
    "    while not done:\n",
    "        key, sample_key = jax.random.split(sample_key, 2)\n",
    "        action = jax.device_get(iql_agent.sample_actions(obs, temperature=0.0))\n",
    "        obs, reward, done ,_ = env.step(action)\n",
    "        os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "        frames.append(env.render(mode='rgb_array'))\n",
    "        os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3'\n",
    "        episode_reward += reward\n",
    "    all_reward.append(episode_reward)\n",
    "    print(episode_reward)\n",
    "save_video.save_video(frames, video_folder='.', fps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our Algo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from jax.numpy import ndarray\n",
    "from ott.geometry import costs\n",
    "from ott.math import utils as mu\n",
    "\n",
    "@jax.tree_util.register_pytree_node_class\n",
    "class MyCost(costs.CostFn):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.cost = costs.SqEuclidean()\n",
    "\n",
    "    def pairwise(self, x: ndarray, y: ndarray) -> float:\n",
    "        d = self.cost(x, y)\n",
    "        return jnp.minimum(5000, d)\n",
    "\n",
    "class OTRewardsExpert:\n",
    "\n",
    "    def __init__(\n",
    "        self, expert_traj,\n",
    "    ):\n",
    "        self.expert_states = expert_traj\n",
    "        self.expert_z = eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=0)\n",
    "        self.sub_steps = 5\n",
    "\n",
    "    def make_subs(self, z, sub_steps):\n",
    "        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)\n",
    "        return jax.tree_map(lambda arr: arr[sub_indx], z)\n",
    "    \n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def get_z_and_start_index(self, obs):\n",
    "        # obs - trajectory\n",
    "        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "        diff = jnp.linalg.norm(z[0][jnp.newaxis,] - self.expert_z, axis=-1) #eqx_get_state_traj(icvf_model.value_learner.model, z[0][None], self.expert_z).mean(1)#z[0][jnp.newaxis,] - self.expert_z\n",
    "        i_min = jnp.argmin(diff)#jnp.argmin((diff**2).sum(-1)).squeeze()\n",
    "        return z, i_min, diff\n",
    "\n",
    "    def compute_rewards(\n",
    "        self,\n",
    "        dataset\n",
    "    ):\n",
    "        i0 = 0\n",
    "        rewards = []\n",
    "        observations = dataset.dataset_dict['observations']\n",
    "        episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "        \n",
    "        for i1 in tqdm(range(len(episode_starts))):\n",
    "            zi, start_index, diff = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])\n",
    "            ri = self.compute_rewards_one_episode(zi, self.expert_z[start_index:])\n",
    "            #print(eval_ensemble_icvf_latent_zzz(icvf_model.value_learner.model, zi[0][None], self.expert_z[start_index][5][None], self.expert_z[start_index][5][None]).mean(0))\n",
    "            rewards.append(jax.device_get(ri))\n",
    "                  \n",
    "        return np.concatenate(rewards)#, selected_index\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def compute_rewards_one_episode(\n",
    "        self, episode_obs, expert_obs\n",
    "    ):\n",
    "\n",
    "        za_1 = episode_obs\n",
    "        za_2 = self.make_subs(za_1, self.sub_steps)\n",
    "        x = jnp.concatenate([za_1, za_2], axis=1)\n",
    "\n",
    "        ze_1 = expert_obs\n",
    "        ze_2 = self.make_subs(ze_1, self.sub_steps)\n",
    "        y = jnp.concatenate([ze_1, ze_2], axis=1)\n",
    "        \n",
    "        geom = pointcloud.PointCloud(x, y, epsilon=0.001)\n",
    "        ot_prob = linear_problem.LinearProblem(geom)\n",
    "        solver = sinkhorn.Sinkhorn(max_iterations=200, use_danskin=True)\n",
    "\n",
    "        ot_sink = solver(ot_prob)\n",
    "        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)\n",
    "        rewards = -transp_cost * episode_obs.shape[0] / 10\n",
    "\n",
    "        return rewards\n",
    "        \n",
    "# expert = OTRewardsExpert(expert_trajectory)\n",
    "# rewards = expert.compute_rewards(gc_agent_dataset.dataset)\n",
    "\n",
    "from src.dataset import Dataset\n",
    "\n",
    "class ExpRewardsScaler:\n",
    "    def init(self, rewards: np.ndarray):\n",
    "        self.min = np.quantile(np.abs(rewards).reshape(-1), 0.0)\n",
    "        self.max = np.quantile(np.abs(rewards).reshape(-1), 0.95)\n",
    "\n",
    "    def scale(self, rewards: np.ndarray):\n",
    "        # From paper\n",
    "        return 5* np.exp(5 * rewards)\n",
    "\n",
    "\n",
    "def get_subs(dataset: GCSDataset, add_steps: int):\n",
    "    terminal_locs = dataset.terminal_locs\n",
    "    indx = np.arange(dataset.dataset.dataset_dict['observations'].shape[0])\n",
    "    final_state_indx = terminal_locs[np.searchsorted(terminal_locs, indx)] \n",
    "    way_indx = np.minimum(indx + add_steps, final_state_indx)\n",
    "    subs = jax.tree_map(lambda arr: arr[way_indx], dataset.dataset.dataset_dict['observations'])\n",
    "    return subs\n",
    "\n",
    "scaler = ExpRewardsScaler()\n",
    "scaled_rewards = scaler.scale(rewards).astype(np.float32)\n",
    "\n",
    "## Apply iql scaling\n",
    "from utils.ds_builder import load_trajectories\n",
    "    \n",
    "offline_traj = load_trajectories(env.spec.id, scaled_rewards)\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    \"\"\"Rescale rewards based on max/min from the dataset.\n",
    "    This is also used in the original IQL implementation.\n",
    "    \"\"\"\n",
    "    trajs = trajs.copy()\n",
    "    \n",
    "    def compute_returns(tr):\n",
    "        return sum([step[2] for step in tr])\n",
    "    \n",
    "    trajs.sort(key=compute_returns)\n",
    "    reward_scale = 1000.0 / (\n",
    "      compute_returns(trajs[-1]) - compute_returns(trajs[0]))\n",
    "    return reward_scale\n",
    "    \n",
    "subs_15 = get_subs(gc_agent_dataset, 15)\n",
    "subs_10 = get_subs(gc_agent_dataset, 10)\n",
    "subs_5 = get_subs(gc_agent_dataset, 5)\n",
    "\n",
    "ds = gc_agent_dataset.dataset.dataset_dict\n",
    "episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "data_with_ot_rewards = Dataset(\n",
    "    {'observations': np.concatenate([ds['observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'next_observations': np.concatenate([ds['next_observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'actions': np.concatenate([ds['actions'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'rewards':scaled_rewards * compute_iql_reward_scale(offline_traj),\n",
    "    'masks': 1.0 - np.concatenate([ds['dones'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'sub_observations_5': np.concatenate([subs_5[episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min], \n",
    "    'sub_observations_10':np.concatenate([subs_10[episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min], \n",
    "    'sub_observations_15': np.concatenate([subs_15[episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32)})#[scaled_rewards > r_min]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%wandb offline\n",
    "wandb.init(project='ForPaper', group='AILOT-Mujoco', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 5001, 'expectile':0.5, 'discount': 0.99, 'temperature': 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "iql_agent_ot = Learner(\n",
    "        1337,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=wandb.config.max_steps,\n",
    "        expectile=wandb.config.expectile,\n",
    "        discount=wandb.config.discount,\n",
    "        temperature=wandb.config.temperature)\n",
    "\n",
    "pbar = tqdm(range(1_000_000 + 1))\n",
    "expert = OTRewardsExpert(expert_trajectory)\n",
    "\n",
    "for i in pbar:\n",
    "    sample = data_with_ot_rewards.sample(256)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample['actions'],\n",
    "        rewards= sample[\"rewards\"],\n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent_ot.update(batch)\n",
    "    update_info['adv'] = None\n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)\n",
    "        print(eval_stats)\n",
    "        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "        wandb.log({f\"Eval/{key}\": value for key, value in eval_stats.items()})\n",
    "        pbar.set_postfix(update_info)\n",
    "    if i % 2000 == 0:\n",
    "        wandb.log({f\"Training/{key}\": value for key, value in update_info.items()})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_stats = evaluate(iql_agent_ot, env, num_episodes=50)\n",
    "eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "wandb.log({f\"FinalEval/{key}\": value for key, value in eval_stats.items()})\n",
    "print(eval_stats)\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
