{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-12-27T15:08:00.518601Z",
     "start_time": "2023-12-27T15:07:58.184122Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:283: DeprecationWarning: the load_module() method is deprecated and slated for removal in Python 3.12; use exec_module() instead\n",
      "pybullet build time: Nov 28 2023 23:45:17\n",
      "/home/m_bobrin/anaconda3/envs/jax/lib/python3.10/site-packages/pybullet_envs/env_bases.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n",
      "  from pkg_resources import parse_version\n",
      "/home/m_bobrin/anaconda3/envs/jax/lib/python3.10/site-packages/flax/core/meta.py:31: DeprecationWarning: jax.experimental.maps and jax.experimental.maps.xmap are deprecated and will be removed in a future release. Use jax.experimental.shard_map or jax.vmap with the spmd_axis_name argument for expressing SPMD device-parallel computations. Please file an issue on https://github.com/google/jax/issues if neither jax.experimental.shard_map nor jax.vmap are suitable for your use case.\n",
      "  from jax.experimental import maps\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "os.environ['MUJOCO_GL']='egl'\n",
    "os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "\n",
    "import wandb\n",
    "import gym\n",
    "import d4rl\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use('fivethirtyeight')\n",
    "\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import functools\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from ott.geometry import pointcloud\n",
    "from ott.problems.linear import linear_problem\n",
    "from ott.solvers.linear import sinkhorn\n",
    "\n",
    "import optax\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def eval_ensemble_psi(ensemble, s):\n",
    "    \"\"\"Apply every ensemble member's `psi_net` to each row of `s`.\n",
    "\n",
    "    Outer vmap: maps the ensemble's array leaves along axis 0 while `s` is shared.\n",
    "    Inner vmap: maps `psi_net` over the leading axis of `s`.\n",
    "    \"\"\"\n",
    "    psi_per_row = eqx.filter_vmap(ensemble.psi_net)\n",
    "    return psi_per_row(s)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def eval_ensemble_phi(ensemble, s):\n",
    "    \"\"\"Apply every ensemble member's `phi_net` to each row of `s`.\n",
    "\n",
    "    Outer vmap: maps the ensemble's array leaves along axis 0 while `s` is shared.\n",
    "    Inner vmap: maps `phi_net` over the leading axis of `s`.\n",
    "    \"\"\"\n",
    "    phi_per_row = eqx.filter_vmap(ensemble.phi_net)\n",
    "    return phi_per_row(s)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))\n",
    "def eval_ensemble_icvf_viz(ensemble, s, g, z):\n",
    "    \"\"\"Per-ensemble-member `classic_icvf_initial(s, g, z)`, mapped row-wise over s, g and z.\"\"\"\n",
    "    initial_rows = eqx.filter_vmap(ensemble.classic_icvf_initial)\n",
    "    return initial_rows(s, g, z)\n",
    "\n",
    "# V(s, g, z) via `classic_icvf` (author's note: g - dim 29, z - dim 256).\n",
    "# NOTE(review): the 29 looks stale -- the halfcheetah observations in this notebook are 17-d.\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))\n",
    "def eval_ensemble_icvf_latent_z(ensemble, s, g, z):\n",
    "    \"\"\"Per-ensemble-member `classic_icvf(s, g, z)`, mapped row-wise over s, g and z.\"\"\"\n",
    "    classic_rows = eqx.filter_vmap(ensemble.classic_icvf)\n",
    "    return classic_rows(s, g, z)\n",
    "\n",
    "# V(s, g, z) via `icvf_zz` (author's note: g and z are both 256-d).\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))\n",
    "def eval_ensemble_icvf_latent_zz(ensemble, s, g, z):\n",
    "    \"\"\"Per-ensemble-member `icvf_zz(s, g, z)`, mapped row-wise over s, g and z.\"\"\"\n",
    "    zz_rows = eqx.filter_vmap(ensemble.icvf_zz)\n",
    "    return zz_rows(s, g, z)\n",
    "    \n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))\n",
    "def eval_ensemble_icvf_latent_zzz(ensemble, s, g, z):\n",
    "    \"\"\"Per-ensemble-member `icvf_zzz(s, g, z)`, mapped row-wise over s, g and z.\"\"\"\n",
    "    zzz_rows = eqx.filter_vmap(ensemble.icvf_zzz)\n",
    "    return zzz_rows(s, g, z)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0, z=None))\n",
    "def eqx_get_state_traj(ensemble, s, z):\n",
    "    \"\"\"Evaluate `icvf_zzz` of every state in `s` against every row of `z`.\n",
    "\n",
    "    The outer vmap maps over the leading axis of `s` (ensemble and `z` are shared).\n",
    "    For each state, the state is tiled to `z.shape[0]` rows and evaluated with `z`\n",
    "    as both the `g` and `z` arguments, once per ensemble member.\n",
    "    \"\"\"\n",
    "    s_rows = jnp.tile(s, (z.shape[0], 1))\n",
    "    # Use the ensemble passed in by the caller instead of a global model.\n",
    "    return eval_ensemble_icvf_latent_zzz(ensemble, s_rows, z, z)\n",
    "\n",
    "@eqx.filter_jit\n",
    "def get_gcvalue(agent, s, g, z):\n",
    "    \"\"\"Ensemble-averaged value V(s, g, z) from `agent.value_learner.model`.\n",
    "\n",
    "    Averages over the leading (ensemble) axis, so any ensemble size works; for a\n",
    "    2-member ensemble this equals `(v_1 + v_2) / 2`.\n",
    "    \"\"\"\n",
    "    v_per_member = eval_ensemble_icvf_viz(agent.value_learner.model, s, g, z)\n",
    "    return jnp.mean(v_per_member, axis=0)\n",
    "\n",
    "def get_v_gz(agent, initial_state, target_goal, observations):\n",
    "    \"\"\"Negated value V(initial_state, obs, target_goal) for every row of `observations`.\n",
    "\n",
    "    `initial_state` and `target_goal` are tiled to one row per observation (assumed to be\n",
    "    single vectors); the observations are passed as the `g` argument of `get_gcvalue`.\n",
    "    \"\"\"\n",
    "    n_obs = observations.shape[0]\n",
    "    s_rows = jnp.tile(initial_state, (n_obs, 1))\n",
    "    z_rows = jnp.tile(target_goal, (n_obs, 1))\n",
    "    return -get_gcvalue(agent, s_rows, observations, z_rows)\n",
    "    \n",
    "def get_v_zz(agent, goal, observations):\n",
    "    \"\"\"Value of every observation towards `goal`, with `goal` used for both g and z.\"\"\"\n",
    "    goal_rows = jnp.tile(goal, (observations.shape[0], 1))\n",
    "    return get_gcvalue(agent, observations, goal_rows, goal_rows)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(agent=None, obs=None, goal=0))\n",
    "def get_v_zz_heatmap(agent, obs, goal):\n",
    "    \"\"\"For each row of `goal` (e.g. the states of a trajectory), values of all `obs` towards it.\n",
    "\n",
    "    Per goal row this is the same computation as `get_v_zz(agent, goal_row, obs)`.\n",
    "    \"\"\"\n",
    "    n_obs = obs.shape[0]\n",
    "    goal_rows = jnp.tile(goal, (n_obs, 1))\n",
    "    return get_gcvalue(agent, obs, goal_rows, goal_rows)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0))\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def batched_eval_ensemble_psi(ensemble, s):\n",
    "    \"\"\"Ensemble `psi_net` evaluation lifted over an extra leading batch axis of `s`.\n",
    "\n",
    "    Outer vmap: maps `s` along axis 0 (e.g. several trajectories), ensemble shared.\n",
    "    Inner vmap: maps the ensemble's array leaves along axis 0, `s` shared.\n",
    "    Body: applies `psi_net` to each row of the per-batch `s`.\n",
    "    \"\"\"\n",
    "    psi_rows = eqx.filter_vmap(ensemble.psi_net)\n",
    "    return psi_rows(s)\n",
    "\n",
    "# Interactive setup: inline figures and live reloading of edited project modules.\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Compact array printing: 2 decimals, fixed-point notation.\n",
    "jnp.set_printoptions(precision=2, suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-27T15:07:58.310Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 20:32:57.470689: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n",
      "/home/m_bobrin/anaconda3/envs/jax/lib/python3.10/site-packages/gym/core.py:172: DeprecationWarning: \u001b[33mWARN: Function `env.seed(seed)` is marked as deprecated and will be removed in the future. Please use `env.reset(seed=seed) instead.\u001b[0m\n",
      "  deprecation(\n",
      "load datafile: 100%|██████████| 9/9 [00:02<00:00,  3.77it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a04f271a857b43a3849c25edcb08fc28",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1998000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expert returns [11086.897860401426, 11124.63724167645, 11092.346300512552, 11188.732963606715, 11239.283746674657], mean 11146.37962257436\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile: 100%|██████████| 11/11 [00:00<00:00, 46.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of terminal states: 1\n",
      "Number of terminal states: 203\n"
     ]
    }
   ],
   "source": [
    "from src.gc_dataset import GCSDataset\n",
    "from utils.ds_builder import setup_datasets\n",
    "\n",
    "# Expert demos come from medium-expert (expert_num=5), agent data from medium-replay.\n",
    "env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(\n",
    "    expert_env_name=\"halfcheetah-medium-expert-v2\",\n",
    "    agent_env_name=\"halfcheetah-medium-replay-v2\",\n",
    "    expert_num=5,\n",
    "    normalize_agent_states=False,\n",
    ")\n",
    "\n",
    "# Both datasets are wrapped with the same default goal-conditioned sampling config.\n",
    "gcsds_params = GCSDataset.get_default_config()\n",
    "gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)\n",
    "gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)\n",
    "\n",
    "# All expert observations, still flat (all expert trajectories back to back).\n",
    "expert_trajectory = gc_expert_dataset.dataset.dataset_dict['observations']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/AILOT\n",
      "Extra kwargs: {}\n"
     ]
    }
   ],
   "source": [
    "from src.agents import icvf\n",
    "\n",
    "# Work from the repository root (presumably needed so `pretrained_folder` resolves).\n",
    "# Guarded so that re-running this cell does not keep moving up one directory;\n",
    "# assumes the notebook lives in the repo's `notebooks/` folder.\n",
    "if os.path.basename(os.getcwd()) == \"notebooks\":\n",
    "    os.chdir(\"..\")\n",
    "\n",
    "icvf_model = icvf.create_eqx_learner(\n",
    "    seed=42,\n",
    "    observations=gc_agent_dataset.dataset.dataset_dict['observations'][0],\n",
    "    hidden_dims=[256, 256],\n",
    "    pretrained_folder=\"halfcheetah-medium-expert\",\n",
    "    load_pretrained_icvf=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# IQL Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/anaconda3/envs/jax/lib/python3.10/site-packages/wandb/sdk/lib/ipython.py:77: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import HTML, display  # type: ignore\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.4"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "W&B syncing is set to <code>`offline`<code> in this directory.  <br/>Run <code>`wandb online`<code> or set <code>WANDB_MODE=online<code> to enable cloud syncing."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
      ],
      "text/plain": [
       "<wandb.sdk.wandb_run.Run at 0x7936b3f63f70>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Offline W&B run holding the IQL baseline hyperparameters and metrics.\n",
    "wandb.init(\n",
    "    project='D4RL-jupyter',\n",
    "    group='IQL-base',\n",
    "    name=env.spec.id,\n",
    "    mode='offline',\n",
    "    config=dict(max_steps=1_000_000, seed=329399, expectile=0.9, discount=0.99, temperature=3),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "318fdbc9e73c4225b9e992caa5121dcb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/AILOT/notebooks/../src/agents/iql_flax/learner.py:19: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n",
      "  new_target_params = jax.tree_map(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'return': 3149.534634436598, 'length': 1000.0}\n",
      "{'return': 4131.887007751274, 'length': 1000.0}\n",
      "{'return': 4286.3978851200245, 'length': 1000.0}\n",
      "{'return': 4976.290986408293, 'length': 1000.0}\n",
      "{'return': 5086.705687927414, 'length': 1000.0}\n",
      "{'return': 5060.435645161143, 'length': 1000.0}\n",
      "{'return': 5080.665576175302, 'length': 1000.0}\n",
      "{'return': 5103.258153680383, 'length': 1000.0}\n",
      "{'return': 5169.252698617677, 'length': 1000.0}\n",
      "{'return': 5084.815931102232, 'length': 1000.0}\n",
      "{'return': 5270.085662506708, 'length': 1000.0}\n",
      "{'return': 5191.690483569702, 'length': 1000.0}\n",
      "{'return': 5258.913090504395, 'length': 1000.0}\n",
      "{'return': 4993.8611321546905, 'length': 1000.0}\n",
      "{'return': 5242.375887033695, 'length': 1000.0}\n",
      "{'return': 5172.271450205778, 'length': 1000.0}\n",
      "{'return': 5333.186050278297, 'length': 1000.0}\n",
      "{'return': 5249.911500189888, 'length': 1000.0}\n",
      "{'return': 5145.692304728164, 'length': 1000.0}\n"
     ]
    }
   ],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "max_steps = wandb.config.max_steps\n",
    "batch_size = 256\n",
    "eval_every = 50_000  # steps between evaluation rollouts\n",
    "log_every = 3_000    # steps between W&B training logs\n",
    "\n",
    "# The sampled observation/action presumably only tell Learner the input shapes.\n",
    "iql_agent = Learner(\n",
    "    wandb.config.seed,\n",
    "    env.observation_space.sample()[np.newaxis],\n",
    "    env.action_space.sample()[np.newaxis],\n",
    "    max_steps=max_steps,\n",
    "    expectile=wandb.config.expectile,\n",
    "    discount=wandb.config.discount,\n",
    "    temperature=wandb.config.temperature,\n",
    ")\n",
    "\n",
    "pbar = tqdm(range(max_steps))\n",
    "for i in pbar:\n",
    "    # Use the configured batch size (was a duplicated literal 256).\n",
    "    sample = gc_agent_dataset.dataset.sample(batch_size=batch_size)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions=sample[\"actions\"],\n",
    "        rewards=sample[\"rewards\"],\n",
    "        masks=sample[\"masks\"],\n",
    "    )\n",
    "    update_info = iql_agent.update(batch)\n",
    "    # 'adv' is blanked before logging/display (presumably a per-sample array).\n",
    "    update_info['adv'] = None\n",
    "\n",
    "    if i % eval_every == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent, env, num_episodes=10)\n",
    "        wandb.log({'Eval': eval_stats})\n",
    "        print(eval_stats)\n",
    "        pbar.set_postfix(update_info)\n",
    "\n",
    "    if i % log_every == 0:\n",
    "        wandb.log({'Training/': update_info})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'return': 5259.223621283285, 'length': 1000.0}\n"
     ]
    }
   ],
   "source": [
    "# Final, longer evaluation of the trained IQL baseline (20 episodes).\n",
    "eval_stats = evaluate(\n",
    "    iql_agent,\n",
    "    env,\n",
    "    num_episodes=20,\n",
    ")\n",
    "print(eval_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'OfflineHalfCheetahEnv' object has no attribute 'set_xy'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 25\u001b[0m\n\u001b[1;32m     22\u001b[0m sample_key \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mPRNGKey(\u001b[38;5;241m42\u001b[39m)\n\u001b[1;32m     24\u001b[0m env\u001b[38;5;241m.\u001b[39mreset()\n\u001b[0;32m---> 25\u001b[0m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrapped_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset_xy\u001b[49m((start_point[\u001b[38;5;241m0\u001b[39m], start_point[\u001b[38;5;241m1\u001b[39m]))\n\u001b[1;32m     26\u001b[0m env\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39mwrapped_env\u001b[38;5;241m.\u001b[39mset_target((target_goal[\u001b[38;5;241m0\u001b[39m], target_goal[\u001b[38;5;241m1\u001b[39m]))\n\u001b[1;32m     27\u001b[0m start_point \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39mwrapped_env\u001b[38;5;241m.\u001b[39m_get_obs()\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'OfflineHalfCheetahEnv' object has no attribute 'set_xy'"
     ]
    }
   ],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "# (scale, shift) pairs mapping the first two state coordinates onto the background\n",
    "# image used by plot_traj_image (presumably pixel coordinates): [x-map, y-map].\n",
    "scales_shifts = [(6, 40), (-6, 230)]\n",
    "def scale_and_shift(x, lst):\n",
    "    \"\"\"Return `scale * x + shift` where `scale, shift = lst[0], lst[1]`.\"\"\"\n",
    "    scale, shift = lst[0], lst[1]\n",
    "    return scale * x + shift\n",
    "    \n",
    "def plot_traj_image(traj, start, goal, subgoals, bgpath):\n",
    "    \"\"\"Scatter expert states, start, goal and a subgoal over a background image.\n",
    "\n",
    "    Only the first two coordinates of each point are plotted, each mapped through\n",
    "    the matching entry of `scales_shifts`. Returns the matplotlib Axes.\n",
    "\n",
    "    NOTE(review): `traj` is unused -- the global `expert_trajectory` (assumed 2-D,\n",
    "    (N, obs_dim)) is plotted instead; it becomes 3-D after the later reshape cell.\n",
    "    \"\"\"\n",
    "    x_map, y_map = scales_shifts[0], scales_shifts[1]\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    ax.imshow(plt.imread(bgpath))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.scatter(scale_and_shift(expert_trajectory[:, 0], x_map), scale_and_shift(expert_trajectory[:, 1], y_map), alpha=1, label='trajectory', color='orange')\n",
    "    ax.scatter(scale_and_shift(start[0], x_map), scale_and_shift(start[1], y_map), c='g', s=100, label='start')\n",
    "    ax.scatter(scale_and_shift(goal[0], x_map), scale_and_shift(goal[1], y_map), c='r', s=100, label='goal')\n",
    "    # Wrap in $...$ so matplotlib renders the subscript instead of the raw text.\n",
    "    ax.scatter(scale_and_shift(subgoals[0], x_map), scale_and_shift(subgoals[1], y_map), c='b', s=100, label=r'$S_{t+k}$')\n",
    "    ax.legend(fontsize=10)\n",
    "    return ax\n",
    "\n",
    "sample = gc_agent_dataset.dataset.sample(1)\n",
    "start_point = expert_trajectory[0]\n",
    "target_goal = expert_trajectory[-1]\n",
    "\n",
    "obs = env.reset()\n",
    "base_env = env.env.env.wrapped_env\n",
    "# Position/target control only exists on the maze envs this cell was written for;\n",
    "# OfflineHalfCheetahEnv has no `set_xy` (AttributeError), so fall back to the reset obs.\n",
    "is_maze = hasattr(base_env, 'set_xy') and hasattr(base_env, 'set_target')\n",
    "if is_maze:\n",
    "    base_env.set_xy((start_point[0], start_point[1]))\n",
    "    base_env.set_target((target_goal[0], target_goal[1]))\n",
    "    start_point = base_env._get_obs()\n",
    "else:\n",
    "    start_point = obs\n",
    "curr_point = start_point\n",
    "frames = []\n",
    "\n",
    "i = 0\n",
    "done = False\n",
    "while not done:\n",
    "    action = jax.device_get(iql_agent.sample_actions(curr_point.squeeze(), temperature=0.0))\n",
    "    new_obs, reward, done, _ = env.step(action)\n",
    "\n",
    "    # NOTE(review): changing CUDA_VISIBLE_DEVICES after JAX has initialised presumably only\n",
    "    # affects the renderer, not JAX; kept as in the original.\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "    frames.append(env.render(mode='rgb_array'))\n",
    "    os.environ['CUDA_VISIBLE_DEVICES']='1,2,3'\n",
    "    if done:\n",
    "        print(reward)\n",
    "    # The background image and coordinate maps are for antmaze-large, so only plot for maze envs.\n",
    "    if is_maze and i % 100 == 0:\n",
    "        plot_traj_image(sample, new_obs, target_goal, new_obs, \"/home/m_bobrin/AILOT/notebooks/antmaze-large.png\")\n",
    "        plt.show()\n",
    "    curr_point = new_obs\n",
    "    i += 1\n",
    "save_video.save_video(frames, video_folder='.', fps=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Our Algorithm: optimal-transport rewards in the ICVF latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from jax.numpy import ndarray\n",
    "from ott.geometry import costs\n",
    "from ott.math import utils as mu\n",
    "\n",
    "class OTRewardsExpert:\n",
    "    \"\"\"OT reward labeller, first version (redefined in a later cell of this notebook).\n",
    "\n",
    "    Expert states are embedded with the ICVF `psi_net` ensemble (averaged over members).\n",
    "    Each agent episode is matched to the expert latents with entropic OT (Sinkhorn) on\n",
    "    consecutive-latent pairs; the per-step transport cost becomes a negative reward.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, expert_traj,\n",
    "    ):\n",
    "        self.expert_states = expert_traj\n",
    "        # Ensemble-mean psi latents; the extra leading axis is for several expert trajectories.\n",
    "        self.expert_z = batched_eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=1)\n",
    "        self.sub_steps = 1\n",
    "\n",
    "    def make_subs(self, z, sub_steps):\n",
    "        \"\"\"Return `z` shifted forward by `sub_steps` along axis 0, clamped at the last row.\"\"\"\n",
    "        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)\n",
    "        return jax.tree_util.tree_map(lambda arr: arr[sub_indx], z)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def get_z_and_start_index(self, obs):\n",
    "        \"\"\"Embed `obs` and locate the expert latent nearest to the first embedded state.\n",
    "\n",
    "        Returns (z, i_min, diff): ensemble-mean latents of `obs`, the argmin over the\n",
    "        last axis of `diff`, and the L2 distances `diff`.\n",
    "        \"\"\"\n",
    "        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "        diff = jnp.linalg.norm(z[0][None] - self.expert_z, axis=-1)\n",
    "        i_min = diff.argmin(axis=-1)\n",
    "        return z, i_min, diff\n",
    "\n",
    "    def compute_rewards(\n",
    "        self,\n",
    "        dataset\n",
    "    ):\n",
    "        \"\"\"Label every episode of `dataset` with OT rewards; returns one flat numpy array.\"\"\"\n",
    "        rewards = []\n",
    "        observations = dataset.dataset_dict['observations']\n",
    "        # Episode boundaries must come from the same dataset as the observations.\n",
    "        episode_starts, episode_ends, _ = dataset._trajectory_boundaries_and_returns()\n",
    "\n",
    "        for i1 in tqdm(range(len(episode_starts))):\n",
    "            zi, start_index, _ = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])\n",
    "            # NOTE(review): with a multi-trajectory expert_z, start_index is an array and this\n",
    "            # slice presumably fails; the later redefinition of this class handles that case.\n",
    "            ri = self.compute_rewards_one_episode(zi, self.expert_z[:, start_index:])\n",
    "            rewards.append(jax.device_get(ri))\n",
    "\n",
    "        return np.concatenate(rewards)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def compute_rewards_one_episode(\n",
    "        self, episode_obs, expert_obs\n",
    "    ):\n",
    "        \"\"\"Sinkhorn OT between (z_t, z_{t+sub_steps}) pairs of the episode and the expert.\n",
    "\n",
    "        The reward of each agent step is minus its transport cost, scaled by\n",
    "        (episode length / 10).\n",
    "        \"\"\"\n",
    "        za_1 = episode_obs\n",
    "        za_2 = self.make_subs(za_1, self.sub_steps)\n",
    "        x = jnp.concatenate([za_1, za_2], axis=1)\n",
    "\n",
    "        ze_1 = expert_obs\n",
    "        ze_2 = self.make_subs(ze_1, self.sub_steps)\n",
    "        y = jnp.concatenate([ze_1, ze_2], axis=1)\n",
    "\n",
    "        geom = pointcloud.PointCloud(x, y, epsilon=0.001)\n",
    "        ot_prob = linear_problem.LinearProblem(geom)\n",
    "        solver = sinkhorn.Sinkhorn(max_iterations=300, use_danskin=True)\n",
    "\n",
    "        ot_sink = solver(ot_prob)\n",
    "        # Cost each agent point pays under the transport plan.\n",
    "        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)\n",
    "        return -transp_cost * episode_obs.shape[0] / 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same definition as in the first cell, repeated here (the later definition wins).\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0))\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def batched_eval_ensemble_psi(ensemble, s):\n",
    "    \"\"\"psi_net of each ensemble member applied to each row of each batch entry of `s`.\"\"\"\n",
    "    apply_psi = eqx.filter_vmap(ensemble.psi_net)\n",
    "    return apply_psi(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4995, 17)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# On a fresh run the expert observations are still flat: (total expert steps, obs dim).\n",
    "np.shape(expert_trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the flat expert observations into equal-length trajectories\n",
    "# (5 = expert_num passed to setup_datasets); re-running gives the same shape.\n",
    "n_expert_trajs = 5\n",
    "expert_trajectory = expert_trajectory.reshape(n_expert_trajs, -1, env.observation_space.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-27T15:08:01.109Z"
    }
   },
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from jax.numpy import ndarray\n",
    "from ott.geometry import costs\n",
    "from ott.math import utils as mu\n",
    "\n",
    "class OTRewardsExpert:\n",
    "    \"\"\"Optimal-transport reward labeller against several expert trajectories.\n",
    "\n",
    "    Expert and agent states are embedded with the ICVF `psi_net` ensemble (averaged over\n",
    "    ensemble members). For every agent episode, each expert trajectory is cut to start at\n",
    "    its latent nearest to the episode's first latent, zero-padded to a fixed length, and\n",
    "    matched to the episode with entropic OT (Sinkhorn) on consecutive-latent pairs\n",
    "    (z_t, z_{t+sub_steps}). The reward of an agent step is the max over experts of\n",
    "    `3 * exp(-400 * transport cost of that step)`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, expert_traj,\n",
    "    ):\n",
    "        self.expert_states = expert_traj\n",
    "        # Ensemble-mean psi latents with a leading axis over expert trajectories.\n",
    "        self.expert_z = batched_eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=1)\n",
    "        self.sub_steps = 1\n",
    "\n",
    "    def make_subs(self, z, sub_steps, axis=0):\n",
    "        \"\"\"Return `z` shifted forward by `sub_steps` along `axis`.\n",
    "\n",
    "        Indices past the end are clamped, so the last entries repeat the final step.\n",
    "        `axis` should be the time axis: 0 for a single (T, d) episode, 1 for a stacked\n",
    "        (n_experts, T, d) batch of expert trajectories.\n",
    "        \"\"\"\n",
    "        n_steps = z.shape[axis]\n",
    "        sub_indx = jnp.minimum(jnp.arange(0, n_steps) + sub_steps, n_steps - 1)\n",
    "        return jax.tree_util.tree_map(lambda arr: jnp.take(arr, sub_indx, axis=axis), z)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def get_z_and_start_index(self, obs):\n",
    "        \"\"\"Embed `obs`; per expert, find the latent closest to the first embedded state.\n",
    "\n",
    "        Returns (z, i_min, diff): the ensemble-mean latents of `obs`, the argmin over\n",
    "        the last axis of `diff` (one index per expert), and the L2 distances `diff`.\n",
    "        \"\"\"\n",
    "        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "        diff = jnp.linalg.norm(z[0][None] - self.expert_z, axis=-1)\n",
    "        i_min = diff.argmin(axis=-1)\n",
    "        return z, i_min, diff\n",
    "\n",
    "    def compute_rewards(\n",
    "        self,\n",
    "        dataset\n",
    "    ):\n",
    "        \"\"\"Label every episode of `dataset` with OT rewards; returns one flat numpy array.\"\"\"\n",
    "        rewards = []\n",
    "        observations = dataset.dataset_dict['observations']\n",
    "        # Boundaries must come from the same dataset as the observations being sliced.\n",
    "        episode_starts, episode_ends, _ = dataset._trajectory_boundaries_and_returns()\n",
    "        # Every expert segment is padded to one fixed length so the jitted OT call sees a\n",
    "        # single static shape; never shorter than the expert trajectories themselves.\n",
    "        pad_len = max(1000, self.expert_z.shape[1])\n",
    "\n",
    "        for i1 in tqdm(range(len(episode_starts))):\n",
    "            zi, start_index, _ = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])\n",
    "            expert_demos = []\n",
    "            for i in range(self.expert_z.shape[0]):\n",
    "                traj = self.expert_z[i][start_index[i]:]\n",
    "                # NOTE(review): the zero padding rows are included as OT support points.\n",
    "                expert_demos.append(jnp.pad(traj, ((0, pad_len - traj.shape[0]), (0, 0)), mode='constant', constant_values=0.))\n",
    "            expert_demos = jnp.stack(expert_demos)\n",
    "            ri = self.compute_rewards_one_episode(zi, expert_demos)\n",
    "            rewards.append(jax.device_get(ri))\n",
    "\n",
    "        return np.concatenate(rewards)\n",
    "\n",
    "    @eqx.filter_vmap(in_axes=(None, None, 0))\n",
    "    def batched_ot(self, x, y):\n",
    "        \"\"\"Per-agent-step OT reward of `x` against each expert in `y` (mapped over y's axis 0).\"\"\"\n",
    "        geom = pointcloud.PointCloud(x, y, epsilon=0.001)\n",
    "        ot_prob = linear_problem.LinearProblem(geom)\n",
    "        solver = sinkhorn.Sinkhorn(max_iterations=200, use_danskin=True)\n",
    "\n",
    "        ot_sink = solver(ot_prob)\n",
    "        # Cost each agent point pays under the transport plan.\n",
    "        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)\n",
    "        rewards = 3 * jnp.exp(-1000 * 0.1 * 4 * transp_cost)\n",
    "        return rewards\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def compute_rewards_one_episode(\n",
    "        self, episode_obs, expert_obs\n",
    "    ):\n",
    "        \"\"\"OT rewards of one episode's latents (T, d) against stacked experts (n_experts, L, d).\"\"\"\n",
    "        za_1 = episode_obs\n",
    "        za_2 = self.make_subs(za_1, self.sub_steps, axis=0)\n",
    "        x = jnp.concatenate([za_1, za_2], axis=-1)\n",
    "\n",
    "        ze_1 = expert_obs\n",
    "        # Shift along time (axis 1); shifting axis 0 would pair steps of different experts.\n",
    "        ze_2 = self.make_subs(ze_1, self.sub_steps, axis=1)\n",
    "        y = jnp.concatenate([ze_1, ze_2], axis=-1)\n",
    "        # batched_ot returns (n_experts, T); keep the best expert per agent step.\n",
    "        rewards = jnp.transpose(self.batched_ot(x, y), axes=(1, 0))\n",
    "        rewards = jnp.max(rewards, axis=-1)\n",
    "        return rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-27T15:08:01.867Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa3508ec32bb4491b8391a58c671eca4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/202 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3565121/1579831919.py:17: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n",
      "  return jax.tree_map(lambda arr: arr[sub_indx], z)\n"
     ]
    }
   ],
   "source": [
    "# Build the OT reward model from the expert demos and relabel every agent episode.\n",
    "expert = OTRewardsExpert(expert_traj=expert_trajectory)\n",
    "rewards = expert.compute_rewards(dataset=gc_agent_dataset.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile:   0%|          | 0/11 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "load datafile: 100%|██████████| 11/11 [00:00<00:00, 48.80it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "91a9470d3876459ba543a8cf1d515739",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/201798 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Apply IQL-style reward scaling.\n",
    "from utils.ds_builder import load_trajectories\n",
    "\n",
    "# Trajectories presumably paired with the OT `rewards`; used below for the IQL reward scale.\n",
    "offline_traj = load_trajectories(env.spec.id, rewards)\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    \"\"\"Rescale rewards based on max/min from the dataset.\n",
    "    This is also used in the original IQL implementation.\n",
    "\n",
    "    Returns `1000 / (max_return - min_return)` over `trajs`, where a trajectory's\n",
    "    return is the sum of `step[2]` over its steps.\n",
    "\n",
    "    Raises ValueError if `trajs` is empty or all trajectory returns are equal.\n",
    "    \"\"\"\n",
    "    returns = [sum(step[2] for step in tr) for tr in trajs]\n",
    "    if not returns:\n",
    "        raise ValueError(\"compute_iql_reward_scale: no trajectories given\")\n",
    "    return_range = max(returns) - min(returns)\n",
    "    if return_range == 0:\n",
    "        raise ValueError(\"compute_iql_reward_scale: all trajectory returns are equal\")\n",
    "    return 1000.0 / return_range\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset import Dataset\n",
    "\n",
    "ds = gc_agent_dataset.dataset.dataset_dict\n",
    "episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "\n",
    "def _stack_episodes(key):\n",
    "    \"\"\"Concatenate `ds[key]` over the episode slices [start, end), cast to float32.\"\"\"\n",
    "    chunks = [ds[key][episode_starts[k]:episode_ends[k]] for k in range(len(episode_starts))]\n",
    "    return np.concatenate(chunks).astype(np.float32)\n",
    "\n",
    "# Transitions in the same per-episode order as `rewards` (computed over the same boundaries).\n",
    "# NOTE(review): masks come from `dones`; if those include timeouts, bootstrapping is also cut\n",
    "# at timeouts -- confirm against the dataset's terminal flags.\n",
    "data_with_ot_rewards = Dataset({\n",
    "    'observations': _stack_episodes('observations'),\n",
    "    'next_observations': _stack_episodes('next_observations'),\n",
    "    'actions': _stack_episodes('actions'),\n",
    "    'rewards': rewards * compute_iql_reward_scale(offline_traj),\n",
    "    'masks': 1.0 - _stack_episodes('dones'),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/anaconda3/envs/jax/lib/python3.10/site-packages/wandb/sdk/lib/ipython.py:77: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import HTML, display  # type: ignore\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Finishing last run (ID:ndnyie3r) before initializing another..."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "441d02a35c394037a1750bdda7846d33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "You can sync this run to the cloud by running:<br/><code>wandb sync /home/m_bobrin/AILOT/wandb/offline-run-20240522_203349-ndnyie3r<code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/offline-run-20240522_203349-ndnyie3r/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Successfully finished last run (ID:ndnyie3r). Initializing new run:<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/anaconda3/envs/jax/lib/python3.10/site-packages/wandb/sdk/lib/ipython.py:77: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import HTML, display  # type: ignore\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.4"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "W&B syncing is set to <code>`offline`<code> in this directory.  <br/>Run <code>`wandb online`<code> or set <code>WANDB_MODE=online<code> to enable cloud syncing."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
      ],
      "text/plain": [
       "<wandb.sdk.wandb_run.Run at 0x7936bc31a1d0>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import wandb\n",
    "\n",
    "wandb.init(project='ForPaper', group='AILOT-MultiExpert', name=env.spec.id, mode='offline',\n",
    "          config={'max_steps': 1_000_000, 'seed': 10, 'expectile':0.7, 'discount': 0.99, 'temperature': 3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-12-26T23:19:09.003Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4290239a8c5446f9b91f8d603ad5d9e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/m_bobrin/AILOT/notebooks/../src/agents/iql_flax/learner.py:19: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n",
      "  new_target_params = jax.tree_map(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'return': -562.8501026335916, 'length': 1000.0}\n",
      "{'return': -676.5860698571375, 'length': 1000.0}\n",
      "{'return': -296.1481801392359, 'length': 1000.0}\n",
      "{'return': 1000.9825919874729, 'length': 1000.0}\n",
      "{'return': 1834.2695342235493, 'length': 1000.0}\n",
      "{'return': 1750.5610151939381, 'length': 1000.0}\n",
      "{'return': 2135.2639414001874, 'length': 1000.0}\n",
      "{'return': 2366.086441378525, 'length': 1000.0}\n",
      "{'return': 2100.953889502385, 'length': 1000.0}\n",
      "{'return': 2055.177778884821, 'length': 1000.0}\n",
      "{'return': 3080.5650080537807, 'length': 1000.0}\n",
      "{'return': 2342.505235065336, 'length': 1000.0}\n",
      "{'return': 2473.7866829827913, 'length': 1000.0}\n",
      "{'return': 2022.2883909806428, 'length': 1000.0}\n",
      "{'return': 2657.7558927913806, 'length': 1000.0}\n",
      "{'return': 2413.120794106511, 'length': 1000.0}\n",
      "{'return': 2649.544684034782, 'length': 1000.0}\n",
      "{'return': 2515.5295698579394, 'length': 1000.0}\n",
      "{'return': 2121.062249081803, 'length': 1000.0}\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[24], line 26\u001b[0m\n\u001b[1;32m     18\u001b[0m sample \u001b[38;5;241m=\u001b[39m data_with_ot_rewards\u001b[38;5;241m.\u001b[39msample(batch_size\u001b[38;5;241m=\u001b[39mbatch_size)\n\u001b[1;32m     19\u001b[0m batch \u001b[38;5;241m=\u001b[39m Batch(\n\u001b[1;32m     20\u001b[0m     observations\u001b[38;5;241m=\u001b[39msample[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobservations\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m     21\u001b[0m     next_observations\u001b[38;5;241m=\u001b[39msample[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnext_observations\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     24\u001b[0m     masks\u001b[38;5;241m=\u001b[39m sample[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmasks\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m     25\u001b[0m )\n\u001b[0;32m---> 26\u001b[0m update_info \u001b[38;5;241m=\u001b[39m \u001b[43miql_agent_ot\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     27\u001b[0m update_info[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124madv\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \n\u001b[1;32m     28\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m50_000\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m i \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "File \u001b[0;32m~/AILOT/notebooks/../src/agents/iql_flax/learner.py:130\u001b[0m, in \u001b[0;36mLearner.update\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mupdate\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: Batch) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m InfoDict:\n\u001b[0;32m--> 130\u001b[0m     new_rng, new_actor, new_critic, new_value, new_target_critic, info \u001b[38;5;241m=\u001b[39m \u001b[43m_update_jit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    131\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrng\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcritic\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtarget_critic\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    132\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiscount\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtau\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpectile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtemperature\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    134\u001b[0m     
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng \u001b[38;5;241m=\u001b[39m new_rng\n\u001b[1;32m    135\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactor \u001b[38;5;241m=\u001b[39m new_actor\n",
      "File \u001b[0;32m<string>:1\u001b[0m, in \u001b[0;36m<lambda>\u001b[0;34m(_cls, count, mu, nu)\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "iql_agent_ot = Learner(\n",
    "        wandb.config.seed,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=wandb.config.max_steps,\n",
    "        expectile=wandb.config.expectile,\n",
    "        discount=wandb.config.discount,\n",
    "        temperature=wandb.config.temperature)\n",
    "\n",
    "pbar = tqdm(range(wandb.config.max_steps))\n",
    "expert = OTRewardsExpert(expert_trajectory)\n",
    "\n",
    "for i in pbar:\n",
    "    sample = data_with_ot_rewards.sample(batch_size=batch_size)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample['actions'],\n",
    "        rewards= sample[\"rewards\"],\n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent_ot.update(batch)\n",
    "    update_info['adv'] = None \n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)\n",
    "        print(eval_stats)\n",
    "        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "        wandb.log({f\"Eval/{key}\": value for key, value in eval_stats.items()})\n",
    "        pbar.set_postfix(update_info)\n",
    "    if i % 2000 == 0:\n",
    "        wandb.log({f\"Training/{key}\": value for key, value in update_info.items()})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84.36980536734823\n",
      "83.19496151175093\n",
      "86.49827431871417\n",
      "84.89019129751595\n",
      "78.6462337610193\n",
      "86.68024646434857\n",
      "79.93628108446032\n",
      "88.0105374728517\n",
      "86.84938015466824\n",
      "86.21227660891945\n",
      "87.4761072626642\n",
      "84.69437106296402\n",
      "84.66783427395559\n",
      "83.63897206648461\n",
      "87.66158863655635\n",
      "82.67445060191794\n",
      "86.27726331492916\n",
      "81.63361660289166\n",
      "81.74831678866148\n",
      "86.14787764592741\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "84.59542931492746"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "means = []\n",
    "for i in range(20):\n",
    "    eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)\n",
    "    means.append(env.get_normalized_score(eval_stats['return'])*100)\n",
    "    print(env.get_normalized_score(eval_stats['return'])*100)\n",
    "    means.append(eval_stats['return'])\n",
    "    wandb.log({\"FinalEval/return\": env.get_normalized_score(eval_stats['return'])*100})\n",
    "np.asarray(means[0::2]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_video.save_video(frames, video_folder='.', fps=25)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MUJOCO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gc_dataset import GCSDataset\n",
    "from utils.ds_builder import setup_datasets\n",
    "\n",
    "env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(expert_env_name=\"walker2d-medium-replay-v2\",\n",
    "                                          agent_env_name=\"walker2d-medium-replay-v2\", expert_num=1,\n",
    "                                          normalize_agent_states=False)\n",
    "\n",
    "gcsds_params = GCSDataset.get_default_config()\n",
    "gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)\n",
    "gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)\n",
    "\n",
    "expert_trajectory = gc_expert_dataset.dataset.dataset_dict['observations'] #c_expert_dataset.get_expert_traj()['observations']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%cd ..\n",
    "from src.agents import icvf\n",
    "icvf_model = icvf.create_eqx_learner(seed=9804,\n",
    "                                     observations=expert_ds.dataset_dict['observations'][0],\n",
    "                                     hidden_dims=[256, 256],\n",
    "                                     pretrained_folder=\"walker2d-medium-replay\",\n",
    "                                     load_pretrained_icvf=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## IQL baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='D4RL-jupyter', group='IQL-base', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 9804, 'expectile':0.95, 'discount': 0.99, 'temperature': 3}) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Apply iql scaling\n",
    "from utils.ds_builder import load_trajectories\n",
    "    \n",
    "offline_traj = load_trajectories(\"halfcheetah-medium-v2\", gc_agent_dataset.dataset.dataset_dict['rewards'])\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    \"\"\"Rescale rewards based on max/min from the dataset.\n",
    "    This is also used in the original IQL implementation.\n",
    "    \"\"\"\n",
    "    trajs = trajs.copy()\n",
    "    \n",
    "    def compute_returns(tr):\n",
    "        return sum([step[2] for step in tr])\n",
    "    \n",
    "    trajs.sort(key=compute_returns)\n",
    "    reward_scale = 1000.0 / (\n",
    "      compute_returns(trajs[-1]) - compute_returns(trajs[0]))\n",
    "    return reward_scale\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iql_scale = compute_iql_reward_scale(offline_traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "max_steps = 1_000_000\n",
    "\n",
    "iql_agent = Learner(\n",
    "        9804,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=max_steps,\n",
    "        expectile=0.95,\n",
    "        discount=0.99,\n",
    "        temperature=3)\n",
    "\n",
    "pbar = tqdm(range(max_steps))\n",
    "for i in pbar:\n",
    "    sample = gc_agent_dataset.dataset.sample(batch_size=256)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample[\"actions\"],\n",
    "        rewards = sample[\"rewards\"] * iql_scale, \n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent.update(batch)\n",
    "    update_info['adv'] = None\n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent, env, num_episodes=10)\n",
    "        print(eval_stats)\n",
    "        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "        wandb.log({'Eval': eval_stats})\n",
    "        pbar.set_postfix(update_info)\n",
    "    if i % 3000 == 0:\n",
    "        wandb.log({'Training/': update_info})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "frames=[]\n",
    "i = 0\n",
    "num_episodes = 1 # 1 for render\n",
    "all_reward = []\n",
    "key = jax.random.PRNGKey(42)\n",
    "\n",
    "for i in range(num_episodes):\n",
    "    episode_reward = 0\n",
    "    key, sample_key = jax.random.split(key, 2)\n",
    "    obs = env.reset()\n",
    "    done = False\n",
    "    while not done:\n",
    "        key, sample_key = jax.random.split(sample_key, 2)\n",
    "        action = jax.device_get(iql_agent.sample_actions(obs, temperature=0.0))\n",
    "        obs, reward, done ,_ = env.step(action)\n",
    "        os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "        frames.append(env.render(mode='rgb_array'))\n",
    "        os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3'\n",
    "        episode_reward += reward\n",
    "    all_reward.append(episode_reward)\n",
    "    print(episode_reward)\n",
    "save_video.save_video(frames, video_folder='.', fps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our Algo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from jax.numpy import ndarray\n",
    "from ott.geometry import costs\n",
    "from ott.math import utils as mu\n",
    "\n",
    "@jax.tree_util.register_pytree_node_class\n",
    "class MyCost(costs.CostFn):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.cost = costs.SqEuclidean()\n",
    "\n",
    "    def pairwise(self, x: ndarray, y: ndarray) -> float:\n",
    "        d = self.cost(x, y)\n",
    "        return jnp.minimum(5000, d)\n",
    "\n",
    "class OTRewardsExpert:\n",
    "\n",
    "    def __init__(\n",
    "        self, expert_traj,\n",
    "    ):\n",
    "        self.expert_states = expert_traj\n",
    "        self.expert_z = eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=0)\n",
    "        self.sub_steps = 5\n",
    "\n",
    "    def make_subs(self, z, sub_steps):\n",
    "        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)\n",
    "        return jax.tree_map(lambda arr: arr[sub_indx], z)\n",
    "    \n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def get_z_and_start_index(self, obs):\n",
    "        # obs - trajectory\n",
    "        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)\n",
    "        diff = jnp.linalg.norm(z[0][jnp.newaxis,] - self.expert_z, axis=-1) #eqx_get_state_traj(icvf_model.value_learner.model, z[0][None], self.expert_z).mean(1)#z[0][jnp.newaxis,] - self.expert_z\n",
    "        i_min = jnp.argmin(diff)#jnp.argmin((diff**2).sum(-1)).squeeze()\n",
    "        return z, i_min, diff\n",
    "\n",
    "    def compute_rewards(\n",
    "        self,\n",
    "        dataset\n",
    "    ):\n",
    "        i0 = 0\n",
    "        rewards = []\n",
    "        observations = dataset.dataset_dict['observations']\n",
    "        episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "        \n",
    "        for i1 in tqdm(range(len(episode_starts))):\n",
    "            zi, start_index, diff = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])\n",
    "            ri = self.compute_rewards_one_episode(zi, self.expert_z[start_index:])\n",
    "            #print(eval_ensemble_icvf_latent_zzz(icvf_model.value_learner.model, zi[0][None], self.expert_z[start_index][5][None], self.expert_z[start_index][5][None]).mean(0))\n",
    "            rewards.append(jax.device_get(ri))\n",
    "                  \n",
    "        return np.concatenate(rewards)#, selected_index\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def compute_rewards_one_episode(\n",
    "        self, episode_obs, expert_obs\n",
    "    ):\n",
    "\n",
    "        za_1 = episode_obs\n",
    "        za_2 = self.make_subs(za_1, self.sub_steps)\n",
    "        x = jnp.concatenate([za_1, za_2], axis=1)\n",
    "\n",
    "        ze_1 = expert_obs\n",
    "        ze_2 = self.make_subs(ze_1, self.sub_steps)\n",
    "        y = jnp.concatenate([ze_1, ze_2], axis=1)\n",
    "        \n",
    "        geom = pointcloud.PointCloud(x, y, epsilon=0.001)\n",
    "        ot_prob = linear_problem.LinearProblem(geom)\n",
    "        solver = sinkhorn.Sinkhorn(max_iterations=200, use_danskin=True)\n",
    "\n",
    "        ot_sink = solver(ot_prob)\n",
    "        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)\n",
    "        rewards = -transp_cost * episode_obs.shape[0] / 10\n",
    "\n",
    "        return rewards\n",
    "        \n",
    "# expert = OTRewardsExpert(expert_trajectory)\n",
    "# rewards = expert.compute_rewards(gc_agent_dataset.dataset)\n",
    "\n",
    "from src.dataset import Dataset\n",
    "\n",
    "class ExpRewardsScaler:\n",
    "    def init(self, rewards: np.ndarray):\n",
    "        self.min = np.quantile(np.abs(rewards).reshape(-1), 0.0)\n",
    "        self.max = np.quantile(np.abs(rewards).reshape(-1), 0.95)\n",
    "\n",
    "    def scale(self, rewards: np.ndarray):\n",
    "        # From paper\n",
    "        return 5* np.exp(5 * rewards)\n",
    "\n",
    "\n",
    "def get_subs(dataset: GCSDataset, add_steps: int):\n",
    "    terminal_locs = dataset.terminal_locs\n",
    "    indx = np.arange(dataset.dataset.dataset_dict['observations'].shape[0])\n",
    "    final_state_indx = terminal_locs[np.searchsorted(terminal_locs, indx)] \n",
    "    way_indx = np.minimum(indx + add_steps, final_state_indx)\n",
    "    subs = jax.tree_map(lambda arr: arr[way_indx], dataset.dataset.dataset_dict['observations'])\n",
    "    return subs\n",
    "\n",
    "scaler = ExpRewardsScaler()\n",
    "scaled_rewards = scaler.scale(rewards).astype(np.float32)\n",
    "\n",
    "## Apply iql scaling\n",
    "from utils.ds_builder import load_trajectories\n",
    "    \n",
    "offline_traj = load_trajectories(env.spec.id, scaled_rewards)\n",
    "    \n",
    "def compute_iql_reward_scale(trajs):\n",
    "    \"\"\"Rescale rewards based on max/min from the dataset.\n",
    "    This is also used in the original IQL implementation.\n",
    "    \"\"\"\n",
    "    trajs = trajs.copy()\n",
    "    \n",
    "    def compute_returns(tr):\n",
    "        return sum([step[2] for step in tr])\n",
    "    \n",
    "    trajs.sort(key=compute_returns)\n",
    "    reward_scale = 1000.0 / (\n",
    "      compute_returns(trajs[-1]) - compute_returns(trajs[0]))\n",
    "    return reward_scale\n",
    "    \n",
    "subs_15 = get_subs(gc_agent_dataset, 15)\n",
    "subs_10 = get_subs(gc_agent_dataset, 10)\n",
    "subs_5 = get_subs(gc_agent_dataset, 5)\n",
    "\n",
    "ds = gc_agent_dataset.dataset.dataset_dict\n",
    "episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()\n",
    "data_with_ot_rewards = Dataset(\n",
    "    {'observations': np.concatenate([ds['observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'next_observations': np.concatenate([ds['next_observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'actions': np.concatenate([ds['actions'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'rewards':scaled_rewards * compute_iql_reward_scale(offline_traj),\n",
    "    'masks': 1.0 - np.concatenate([ds['dones'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],\n",
    "    'sub_observations_5': np.concatenate([subs_5[episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min], \n",
    "    'sub_observations_10':np.concatenate([subs_10[episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min], \n",
    "    'sub_observations_15': np.concatenate([subs_15[episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32)})#[scaled_rewards > r_min]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%wandb offline\n",
    "wandb.init(project='ForPaper', group='AILOT-Mujoco', name=env.spec.id,\n",
    "          config={'max_steps': 1_000_000, 'seed': 5001, 'expectile':0.5, 'discount': 0.99, 'temperature': 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.agents.iql_flax.common import Batch\n",
    "from src.agents.iql_flax.learner import Learner\n",
    "from src.agents.iql_flax.evaluation import evaluate\n",
    "\n",
    "iql_agent_ot = Learner(\n",
    "        1337,\n",
    "        env.observation_space.sample()[np.newaxis],\n",
    "        env.action_space.sample()[np.newaxis],\n",
    "        max_steps=wandb.config.max_steps,\n",
    "        expectile=wandb.config.expectile,\n",
    "        discount=wandb.config.discount,\n",
    "        temperature=wandb.config.temperature)\n",
    "\n",
    "pbar = tqdm(range(1_000_000 + 1))\n",
    "expert = OTRewardsExpert(expert_trajectory)\n",
    "\n",
    "for i in pbar:\n",
    "    sample = data_with_ot_rewards.sample(256)\n",
    "    batch = Batch(\n",
    "        observations=sample[\"observations\"],\n",
    "        next_observations=sample[\"next_observations\"],\n",
    "        actions = sample['actions'],\n",
    "        rewards= sample[\"rewards\"],\n",
    "        masks= sample[\"masks\"]\n",
    "    )\n",
    "    update_info = iql_agent_ot.update(batch)\n",
    "    update_info['adv'] = None\n",
    "    if i % 50_000 == 0 and i > 0:\n",
    "        eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)\n",
    "        print(eval_stats)\n",
    "        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "        wandb.log({f\"Eval/{key}\": value for key, value in eval_stats.items()})\n",
    "        pbar.set_postfix(update_info)\n",
    "    if i % 2000 == 0:\n",
    "        wandb.log({f\"Training/{key}\": value for key, value in update_info.items()})\n",
    "        pbar.set_postfix(update_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_stats = evaluate(iql_agent_ot, env, num_episodes=50)\n",
    "eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100\n",
    "wandb.log({f\"FinalEval/{key}\": value for key, value in eval_stats.items()})\n",
    "print(eval_stats)\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
