{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='2'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "\n",
    "import ott\n",
    "from ott.geometry import pointcloud\n",
    "from ott.problems.linear import linear_problem\n",
    "from ott.solvers.linear import sinkhorn\n",
    "from ott.tools import plot, sinkhorn_divergence\n",
    "from ott.tools.sinkhorn_divergence import SinkhornDivergenceOutput\n",
    "from ott.solvers.linear import implicit_differentiation as imp_diff\n",
    "import equinox as eqx"
   ]
  },
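  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sanity check of the `sinkhorn_divergence` API before wiring it into the training code below, on two toy point clouds. The shapes and `epsilon` here are arbitrary illustrative choices, not values from the experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal API check (illustrative shapes/epsilon, not experiment values).\n",
    "_k1, _k2 = jax.random.split(jax.random.PRNGKey(0))\n",
    "_p = jax.random.normal(_k1, (64, 2))\n",
    "_q = jax.random.normal(_k2, (64, 2)) + 1.0\n",
    "\n",
    "_out = sinkhorn_divergence.sinkhorn_divergence(\n",
    "    pointcloud.PointCloud,  # geometry class; OTT builds (x, y), (x, x), (y, y)\n",
    "    x=_p,\n",
    "    y=_q,\n",
    "    epsilon=0.1,\n",
    ")\n",
    "print(_out.divergence)  # scalar, non-negative up to solver tolerance"
   ]
  },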
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sink_div(combined, states, y, b, key) -> tuple[float, float]:\n",
    "    agent_value, agent_policy = combined\n",
    "    z_dist = eqx.filter_vmap(agent_policy)(states)\n",
    "    z, log_prob = z_dist.sample_and_log_prob(seed=key) # intentions of agent\n",
    "    geom = pointcloud.PointCloud(z, y)\n",
    "    \n",
    "    a = eqx.filter_vmap(agent_value)(states, z).squeeze() # weights for intents of agent\n",
    "    an = jax.nn.softplus(a - jnp.quantile(a, 0.01)) \n",
    "    bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))\n",
    "        \n",
    "\n",
    "    an = an / an.sum()\n",
    "    bn = bn / bn.sum()\n",
    "    ot = sinkhorn_divergence.sinkhorn_divergence(\n",
    "        geom,\n",
    "        x=geom.x,\n",
    "        a=an,\n",
    "        b=bn,\n",
    "        y=geom.y,\n",
    "        static_b=True,\n",
    "        sinkhorn_kwargs={\n",
    "            \"implicit_diff\": imp_diff.ImplicitDiff(),\n",
    "            \"use_danskin\": True,\n",
    "            \"max_iterations\": 2000\n",
    "        },\n",
    "    )\n",
    "    return ot.divergence, (-log_prob.squeeze()).min()"
   ]
  },
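  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`sink_div` turns raw value estimates into transport marginals: shift by the 1% quantile so the lowest values land near zero, apply `softplus` to guarantee positivity, then normalize onto the probability simplex. A quick illustration on random numbers (the `_a` name exists only for this check):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration of the marginal construction used in sink_div.\n",
    "_a = jax.random.normal(jax.random.PRNGKey(1), (250,))\n",
    "_an = jax.nn.softplus(_a - jnp.quantile(_a, 0.01))  # strictly positive weights\n",
    "_an = _an / _an.sum()                               # sums to 1\n",
    "print(_an.min(), _an.sum())"
   ]
  },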
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import equinox as eqx\n",
    "import equinox.nn as eqxnn\n",
    "\n",
    "class MonolithicVF_EQX(eqx.Module):\n",
    "    net: eqx.Module\n",
    "    \n",
    "    def __init__(self, key, state_dim, intents_dim, hidden_dims):\n",
    "        key, mlp_key = jax.random.split(key, 2)\n",
    "        self.net = eqxnn.MLP(\n",
    "            in_size=state_dim + intents_dim, out_size=1, width_size=hidden_dims[-1], depth=len(hidden_dims), key=mlp_key\n",
    "        )\n",
    "        \n",
    "    def __call__(self, observations, intents):\n",
    "        # TODO: Maybe try FiLM conditioning like in SAC-RND?\n",
    "        conditioning = jnp.concatenate([observations, intents], axis=-1)\n",
    "        return self.net(conditioning)"
   ]
  },
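  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check of `MonolithicVF_EQX`, using the same dimensions as the training code below (29-dim states, 256-dim intents); the batch size of 10 is arbitrary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vmapped value function maps (batch, 29) x (batch, 256) -> (batch, 1).\n",
    "_vf = MonolithicVF_EQX(jax.random.PRNGKey(0), state_dim=29, intents_dim=256,\n",
    "                       hidden_dims=[128, 128, 128])\n",
    "_s = jax.random.normal(jax.random.PRNGKey(1), (10, 29))\n",
    "_z = jax.random.normal(jax.random.PRNGKey(2), (10, 256))\n",
    "print(eqx.filter_vmap(_vf)(_s, _z).shape)  # (10, 1)"
   ]
  },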
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(250,)\n"
     ]
    }
   ],
   "source": [
    "from jaxrl_m.common import TrainStateEQX\n",
    "from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy\n",
    "import optax\n",
    "\n",
    "key = jax.random.PRNGKey(42)\n",
    "actor_intents_learner = TrainStateEQX.create(model=GaussianIntentPolicy(key=key,\n",
    "                             hidden_dims=[128, 128, 128],\n",
    "                             state_dim=29,\n",
    "                             intent_dim=2), optim=optax.adam(learning_rate=3e-4))\n",
    "\n",
    "x = jax.random.normal(key, (250, 29)) \n",
    "z_dist = eqx.filter_vmap(actor_intents_learner.model)(x)\n",
    "z, log_prob = z_dist.sample_and_log_prob(seed=key)\n",
    "\n",
    "print(log_prob.shape)"
   ]
  },
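  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TrainStateEQX` comes from `jaxrl_m.common` and is not shown in this notebook; as used here it only needs `create(model=..., optim=...)` and `apply_updates(grads)`. Below is a minimal sketch of what such a wrapper could look like, included purely for readability; it is an assumption, not jaxrl_m's actual implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical minimal re-implementation of the TrainStateEQX interface used\n",
    "# in this notebook; the real class lives in jaxrl_m.common and may differ.\n",
    "class MinimalTrainStateEQX(eqx.Module):\n",
    "    model: eqx.Module\n",
    "    optim: optax.GradientTransformation = eqx.field(static=True)\n",
    "    opt_state: optax.OptState\n",
    "\n",
    "    @classmethod\n",
    "    def create(cls, *, model, optim):\n",
    "        # Optimizer state is built over the model's array leaves only.\n",
    "        return cls(model=model, optim=optim,\n",
    "                   opt_state=optim.init(eqx.filter(model, eqx.is_array)))\n",
    "\n",
    "    def apply_updates(self, grads):\n",
    "        updates, opt_state = self.optim.update(grads, self.opt_state)\n",
    "        return MinimalTrainStateEQX(\n",
    "            model=eqx.apply_updates(self.model, updates),\n",
    "            optim=self.optim,\n",
    "            opt_state=opt_state,\n",
    "        )"
   ]
  },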
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_flow(\n",
    "    x: jnp.ndarray,\n",
    "    y: jnp.ndarray,\n",
    "    b,\n",
    "    cost_fn: callable,\n",
    "    num_iter: int = 6000,\n",
    "    dump_every: int = 50\n",
    "):\n",
    "    def v_loss(agent_policy, agent_value, states, key) -> float:\n",
    "        z_dist = eqx.filter_vmap(agent_policy)(states)\n",
    "        z, _ = z_dist.sample_and_log_prob(seed=key)\n",
    "        v = eqx.filter_vmap(agent_value)(states, z).squeeze()\n",
    "        return -v.mean() * 0.1\n",
    "\n",
    "    cost_fn_vg = eqx.filter_jit(eqx.filter_value_and_grad(cost_fn, has_aux=True))\n",
    "    v_loss_vg = eqx.filter_jit(eqx.filter_value_and_grad(v_loss, has_aux=False))\n",
    "\n",
    "    key = jax.random.PRNGKey(42)\n",
    "    V = TrainStateEQX.create(\n",
    "        model=MonolithicVF_EQX(key, 29, 256, [128, 128, 128]),\n",
    "        optim=optax.adam(learning_rate=3e-4)\n",
    "    )\n",
    "\n",
    "    key, pkey = jax.random.split(key, 2)\n",
    "\n",
    "    actor_intents_learner = TrainStateEQX.create(model=GaussianIntentPolicy(key=key,\n",
    "                             hidden_dims=[128, 128, 128],\n",
    "                             state_dim=29,\n",
    "                             intent_dim=256), optim=optax.adam(learning_rate=3e-4))\n",
    "    \n",
    "    for i in range(0, num_iter + 1):\n",
    "        key, key_6 = jax.random.split(key, 2)\n",
    "\n",
    "        (cost, pmin), (value_grads, policy_grads) = cost_fn_vg((V.model, actor_intents_learner.model), x, y, b, key_6)\n",
    "        v_loss, policy_grads_2 = v_loss_vg(actor_intents_learner.model, V.model, x, key_6)\n",
    "        V = V.apply_updates(value_grads)\n",
    "        policy_grads = jax.tree_map(lambda g1, g2: g1 + g2, policy_grads, policy_grads_2)\n",
    "        actor_intents_learner = actor_intents_learner.apply_updates(policy_grads)\n",
    "\n",
    "        if i % dump_every == 0:\n",
    "            z = eqx.filter_vmap(actor_intents_learner.model)(x).sample(seed=key_6)\n",
    "            a = eqx.filter_vmap(V.model)(x, z).squeeze()\n",
    "            \n",
    "            an = jax.nn.softplus(a - jnp.quantile(a, 0.01))\n",
    "            bn = jax.nn.softplus(b - jnp.quantile(b, 0.01))\n",
    "            an = an / an.sum()\n",
    "            bn = bn / bn.sum()\n",
    "\n",
    "            geom = pointcloud.PointCloud(z, y, epsilon=0.01)\n",
    "            diff = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom, a = an, b = bn)).reg_ot_cost\n",
    "            print(cost, diff, pmin)\n",
    "            print()\n",
    "\n",
    "    return policy.model, V.model\n",
    "    "
   ]
  },
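  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above sums the OT-divergence policy gradient and the auxiliary value-loss gradient leaf-wise; by linearity of differentiation this equals the gradient of the summed objective. A toy check (all names here are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leaf-wise gradient summation == gradient of the summed losses (linearity).\n",
    "_f = lambda w: (w ** 2).sum()\n",
    "_g = lambda w: jnp.abs(w).sum()\n",
    "_w = jnp.array([1.0, -2.0])\n",
    "_summed = jax.tree_util.tree_map(lambda u, v: u + v,\n",
    "                                 jax.grad(_f)(_w), jax.grad(_g)(_w))\n",
    "print(jnp.allclose(_summed, jax.grad(lambda w: _f(w) + _g(w))(_w)))  # True"
   ]
  },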
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sink_div(agent_policy, states, y, b, key) -> tuple[float, float]:\n",
    "    z_dist = eqx.filter_vmap(agent_policy)(states)\n",
    "    z, log_prob = z_dist.sample_and_log_prob(seed=key) # intentions of agent\n",
    "    geom = pointcloud.PointCloud(z, y)\n",
    "    ot = sinkhorn_divergence.sinkhorn_divergence(\n",
    "        geom,\n",
    "        x=geom.x,\n",
    "        y=geom.y,\n",
    "        epsilon=0.1,\n",
    "        static_b=True,\n",
    "        sinkhorn_kwargs={\n",
    "            \"implicit_diff\": imp_diff.ImplicitDiff(),\n",
    "            \"use_danskin\": True,\n",
    "            \"max_iterations\": 2000\n",
    "        },\n",
    "    )\n",
    "    return ot.divergence\n",
    "    \n",
    "def gradient_flow(\n",
    "    x: jnp.ndarray,\n",
    "    y: jnp.ndarray,\n",
    "    b,\n",
    "    cost_fn: callable,\n",
    "    num_iter: int = 6000,\n",
    "    dump_every: int = 50\n",
    "):\n",
    "\n",
    "    cost_fn_vg = eqx.filter_jit(eqx.filter_value_and_grad(cost_fn))\n",
    "    key = jax.random.PRNGKey(42)\n",
    "    key, pkey = jax.random.split(key, 2)\n",
    "\n",
    "    actor_intents_learner = TrainStateEQX.create(model=GaussianIntentPolicy(key=key,\n",
    "                             hidden_dims=[128, 128, 128],\n",
    "                             state_dim=29,\n",
    "                             intent_dim=256), optim=optax.adam(learning_rate=3e-4))\n",
    "    \n",
    "    for i in range(0, num_iter + 1):\n",
    "        key, key_6 = jax.random.split(key, 2)\n",
    "\n",
    "        cost, policy_grads = cost_fn_vg(actor_intents_learner.model, x, y, b, key_6)\n",
    "        actor_intents_learner = actor_intents_learner.apply_updates(policy_grads)\n",
    "\n",
    "        if i % dump_every == 0:\n",
    "            z = eqx.filter_vmap(actor_intents_learner.model)(x).sample(seed=key_6)\n",
    "            geom = pointcloud.PointCloud(z, y, epsilon=0.01)\n",
    "            diff = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom)).reg_ot_cost\n",
    "            print(cost, diff)\n",
    "            print()\n",
    "\n",
    "    return policy.model\n",
    "    "
   ]
  },
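  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The simplified variant above omits the marginals and relies on OTT defaulting to uniform weights when `a` and `b` are not given. A quick check of that default (toy shapes):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LinearProblem falls back to uniform marginals when a/b are omitted.\n",
    "_geom = pointcloud.PointCloud(jnp.zeros((5, 2)), jnp.zeros((7, 2)))\n",
    "_prob = linear_problem.LinearProblem(_geom)\n",
    "print(_prob.a, _prob.b)  # 1/5 each and 1/7 each"
   ]
  },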
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "435.81137 434.6088\n",
      "\n",
      "425.7082 424.34558\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[74], line 7\u001b[0m\n\u001b[1;32m      4\u001b[0m x \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mnormal(key1, (\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m29\u001b[39m))\n\u001b[1;32m      5\u001b[0m y \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mnormal(key2, (\u001b[38;5;241m400\u001b[39m, \u001b[38;5;241m256\u001b[39m))\n\u001b[0;32m----> 7\u001b[0m policy_model, V_model \u001b[38;5;241m=\u001b[39m \u001b[43mgradient_flow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmarginal_b\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcost_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msink_div\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[73], line 40\u001b[0m, in \u001b[0;36mgradient_flow\u001b[0;34m(x, y, b, cost_fn, num_iter, dump_every)\u001b[0m\n\u001b[1;32m     37\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, num_iter \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m):\n\u001b[1;32m     38\u001b[0m     key, key_6 \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(key, \u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 40\u001b[0m     cost, policy_grads \u001b[38;5;241m=\u001b[39m \u001b[43mcost_fn_vg\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactor_intents_learner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey_6\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     41\u001b[0m     actor_intents_learner \u001b[38;5;241m=\u001b[39m actor_intents_learner\u001b[38;5;241m.\u001b[39mapply_updates(policy_grads)\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m dump_every \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "    \u001b[0;31m[... skipping hidden 3 frame]\u001b[0m\n",
      "File \u001b[0;32m~/anaconda3/envs/icvf/lib/python3.9/site-packages/equinox/_module.py:732\u001b[0m, in \u001b[0;36m_unflatten_module\u001b[0;34m(cls, aux, dynamic_field_values)\u001b[0m\n\u001b[1;32m    722\u001b[0m     aux \u001b[38;5;241m=\u001b[39m _FlattenedData(\n\u001b[1;32m    723\u001b[0m         \u001b[38;5;28mtuple\u001b[39m(dynamic_field_names),\n\u001b[1;32m    724\u001b[0m         \u001b[38;5;28mtuple\u001b[39m(static_field_names),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    727\u001b[0m         \u001b[38;5;28mtuple\u001b[39m(wrapper_field_values),\n\u001b[1;32m    728\u001b[0m     )\n\u001b[1;32m    729\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(dynamic_field_values), aux\n\u001b[0;32m--> 732\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_unflatten_module\u001b[39m(\u001b[38;5;28mcls\u001b[39m: \u001b[38;5;28mtype\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModule\u001b[39m\u001b[38;5;124m\"\u001b[39m], aux: _FlattenedData, dynamic_field_values):\n\u001b[1;32m    733\u001b[0m     \u001b[38;5;66;03m# This doesn't go via `__init__`. A user may have done something nontrivial there,\u001b[39;00m\n\u001b[1;32m    734\u001b[0m     \u001b[38;5;66;03m# and the field values may be dummy values as used in various places throughout JAX.\u001b[39;00m\n\u001b[1;32m    735\u001b[0m     \u001b[38;5;66;03m# See also\u001b[39;00m\n\u001b[1;32m    736\u001b[0m     \u001b[38;5;66;03m# https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization,\u001b[39;00m\n\u001b[1;32m    737\u001b[0m     \u001b[38;5;66;03m# which was (I believe) inspired by Equinox's approach here.\u001b[39;00m\n\u001b[1;32m    738\u001b[0m     module \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mobject\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m)\n\u001b[1;32m    739\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m name, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(aux\u001b[38;5;241m.\u001b[39mdynamic_field_names, dynamic_field_values):\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)\n",
    "\n",
    "\n",
    "x = jax.random.normal(key1, (100, 29))\n",
    "y = jax.random.normal(key2, (400, 256))\n",
    "\n",
    "policy_model, V_model = gradient_flow(x, y, marginal_b, cost_fn=sink_div)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
