{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook lets you train a model online on a particular game, running locally, and then inspect and debug the learned model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mediapy\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax.random as jr\n",
    "import jax.tree_util as jtu\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from gameworld.envs import create_gameworld_env\n",
    "\n",
    "from axiom import visualize as vis\n",
    "from axiom import infer as ax\n",
    "from axiom.models import rmm as rmm_tools\n",
    "from axiom.models import imm as imm_tools\n",
    "\n",
    "\n",
    "import defaults\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "\n",
    "import pickle\n",
    "import rich\n",
    "import os\n",
    "\n",
    "# Directory for stored models. exist_ok=True makes creation idempotent and\n",
    "# avoids the race between checking for the directory and creating it.\n",
    "store_path = \"data/models\"\n",
    "os.makedirs(store_path, exist_ok=True)\n",
    "\n",
    "# ignore int64 warnings (note: this filter silences every warning, not only those)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up\n",
    "\n",
    "Specify the game you want to run on, and specify all hyperparams of the model to evaluate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Game to run on; passed to the config parser as `--game` in the next cell.\n",
    "game = \"Explode\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Command-line style overrides handed to the default argument parser.\n",
    "cli_overrides = [\n",
    "    f\"--game={game}\",\n",
    "    \"--num_steps=1000\",\n",
    "    # reduce planning rollouts for faster experimentation\n",
    "    \"--planning_rollouts=128\",\n",
    "    # uncomment these lines to run with a \"fixed\" interacting radius\n",
    "    # \"--fixed_r\",\n",
    "    # \"--r_interacting=1.25\",\n",
    "    # \"--r_interacting_predict=0.416\",\n",
    "]\n",
    "config = defaults.parse_args(cli_overrides)\n",
    "\n",
    "# Show the resolved configuration.\n",
    "rich.print(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment\n",
    "\n",
    "Train the agent by running it for the specified number of steps (default 10k; the configuration above overrides this to 1000). This will generate intermediate reports every 500 steps with the reward curve and a visualization of the rMM model, and will inspect the planner on some failed (negative reward) cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reset seed (NumPy global state and the JAX PRNG key)\n",
    "np.random.seed(config.seed)\n",
    "key = jr.PRNGKey(config.seed)\n",
    "\n",
    "# store some data to inspect later; one independent list per logged quantity\n",
    "inspect = []\n",
    "observations, actions, rewards = [], [], []\n",
    "xs, tracked, identities = [], [], []\n",
    "nc, probs = [], []\n",
    "switches, rmm_switches = [], []\n",
    "used, moving = [], []\n",
    "\n",
    "# create env and fetch the first frame\n",
    "env = create_gameworld_env(config.game, config.simple)\n",
    "obs, _ = env.reset()\n",
    "obs = obs.astype(np.uint8)\n",
    "reward = 0\n",
    "\n",
    "# initialize the agent state from the first observation\n",
    "key, subkey = jr.split(key)\n",
    "carry = ax.init(subkey, config, obs, env.action_space.n)\n",
    "\n",
    "# log the initial step, using 0 as a placeholder action\n",
    "observations.append(obs.astype(np.uint8))\n",
    "actions.append(0)\n",
    "rewards.append(reward)\n",
    "xs.append(carry[\"x\"][config.layer_for_dynamics])\n",
    "tracked.append(carry[\"tracked_obj_ids\"][config.layer_for_dynamics])\n",
    "\n",
    "identity_t = imm_tools.infer_identity(\n",
    "    carry[\"imm_model\"],\n",
    "    xs[-1][..., None],\n",
    "    config.imm.color_only_identity,\n",
    ")\n",
    "identities.append(identity_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function to investigate a plan.\n",
    "\n",
    "def investigate_plan(carry, xs, tracked, observations, actions, idx, t):\n",
    "    \"\"\"Re-run the planner from a logged step and visualize its rollouts.\n",
    "\n",
    "    Uses the notebook-level `config` and `env`.\n",
    "\n",
    "    Args:\n",
    "        carry: agent state; a rebuilt copy is modified, not the object passed in.\n",
    "        xs: logged per-step states of the dynamics layer.\n",
    "        tracked: logged per-step tracked-object masks.\n",
    "        observations: logged frames.\n",
    "        actions: logged actions that were actually taken.\n",
    "        idx: base index into the logs.\n",
    "        t: offset added to `idx`; planning starts from step `idx + t`.\n",
    "\n",
    "    Returns:\n",
    "        The `plan_info` returned by `ax.plan_fn`.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if the logs end before `idx + t` plus the planned horizon.\n",
    "    \"\"\"\n",
    "    step = idx + t\n",
    "\n",
    "    # Identity tree_map rebuilds the pytree containers, so the assignments below\n",
    "    # should not leak into the caller's `carry`.\n",
    "    c = jtu.tree_map(lambda x: x, carry)\n",
    "    c[\"mppi_probs\"] = None\n",
    "    c[\"current_plan\"] = None\n",
    "\n",
    "    c[\"x\"][config.layer_for_dynamics] = xs[step]\n",
    "    c[\"tracked_obj_ids\"][config.layer_for_dynamics] = tracked[step]\n",
    "\n",
    "    # Fixed seed: investigating the same step twice gives the same plan.\n",
    "    _, subkey = jr.split(jr.PRNGKey(0))\n",
    "    action, c, plan_info = ax.plan_fn(subkey, c, config, env.action_space.n)\n",
    "\n",
    "    # Reward summed over the planned timesteps, index 0 of the last axis.\n",
    "    returns = plan_info[\"rewards\"][:, :, 0].sum(0)\n",
    "    best = jnp.argsort(returns)[-1]\n",
    "\n",
    "    print(\"predicted vs actual action\")\n",
    "    print(action, actions[step])\n",
    "    print(\"current plan\")\n",
    "    print(plan_info[\"current_plan\"], returns[best])\n",
    "    print(\"prev plan\")\n",
    "    print(plan_info[\"actions\"][:, 0], returns[0])\n",
    "    print(plan_info[\"probs\"][0])\n",
    "\n",
    "    def _plot(**selection):\n",
    "        # All panels share the same frame, plan and mask; only the selection differs.\n",
    "        return vis.plot_plan(\n",
    "            observations[step],\n",
    "            plan_info,\n",
    "            tracked[step],\n",
    "            carry[\"smm_model\"].stats,\n",
    "            **selection,\n",
    "        )\n",
    "\n",
    "    mediapy.show_images(\n",
    "        {\n",
    "            \"top-1\": _plot(topk=1),\n",
    "            \"prev\": _plot(indices=jnp.array([0])),\n",
    "            \"top-20\": _plot(topk=20),\n",
    "            \"worst-5\": _plot(descending=False),\n",
    "        }\n",
    "    )\n",
    "\n",
    "    # Logged states covering the same steps as the planned rollout.\n",
    "    horizon = plan_info[\"states\"].shape[0]\n",
    "    x = jnp.asarray([xs[i] for i in range(step, step + horizon)])\n",
    "\n",
    "    # Keep (at most) the 100 highest-return rollouts, sorted ascending.\n",
    "    summed = plan_info[\"rewards\"].sum(0)\n",
    "    indices = summed[:, 0].argsort()[-100:]\n",
    "    pred_rewards = summed[indices, 0]\n",
    "    pred_states = plan_info[\"states\"][:, indices, 0].transpose((1, 0, 2, 3))\n",
    "\n",
    "    mediapy.show_image(\n",
    "        vis.rollout_samples_lineplot(\n",
    "            pred_states,\n",
    "            x,\n",
    "            # mask of the step the plan starts from, consistent with the panels above\n",
    "            jnp.argwhere(tracked[step]).flatten(),\n",
    "            pred_rewards,\n",
    "            plan_info[\"rewards\"][:, indices[:5], 0],\n",
    "        ),\n",
    "        width=800,\n",
    "    )\n",
    "\n",
    "    return plan_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra state threaded through successive ax.reduce_fn_rmm calls.\n",
    "bmr_buffer = None, None\n",
    "\n",
    "jax.config.update(\"jax_debug_nans\", False)\n",
    "for t in tqdm(range(config.num_steps)):\n",
    "    # action selection\n",
    "    key, subkey = jr.split(key)\n",
    "    action, carry, plan_info = ax.plan_fn(subkey, carry, config, env.action_space.n)\n",
    "    probs.append(plan_info[\"probs\"])\n",
    "\n",
    "    # step env\n",
    "    obs, reward, done, truncated, info = env.step(action)\n",
    "    obs = obs.astype(np.uint8)\n",
    "\n",
    "    # update models\n",
    "    carry, rec = ax.step_fn(\n",
    "        carry, config, obs, jnp.array(reward), action, num_tracked=0\n",
    "    )\n",
    "\n",
    "    # log stuff\n",
    "    observations.append(obs)\n",
    "    actions.append(action)\n",
    "    rewards.append(reward)\n",
    "    tracked.append(carry[\"tracked_obj_ids\"][config.layer_for_dynamics])\n",
    "    nc.append(carry[\"rmm_model\"].used_mask.sum())\n",
    "\n",
    "    xs.append(carry[\"x\"][config.layer_for_dynamics])\n",
    "    switches.append(rec[\"switches\"])\n",
    "    rmm_switches.append(rec[\"rmm_switches\"])\n",
    "    used.append(carry[\"used\"])\n",
    "    moving.append(carry[\"moving\"])\n",
    "\n",
    "    identity_t = imm_tools.infer_identity(carry[\"imm_model\"], xs[-1][..., None], config.imm.color_only_identity)\n",
    "    identities.append(identity_t)\n",
    "\n",
    "    if done:\n",
    "        # Episode ended: reset and pass the first frame of the new episode\n",
    "        # through the model with update=False.\n",
    "        obs, _ = env.reset()\n",
    "        obs = obs.astype(np.uint8)\n",
    "        reward = 0\n",
    "        carry, rec = ax.step_fn(\n",
    "            carry,\n",
    "            config,\n",
    "            obs,\n",
    "            jnp.array(reward),\n",
    "            jnp.array(0),\n",
    "            num_tracked=0,\n",
    "            update=False,\n",
    "        )\n",
    "\n",
    "        observations.append(obs)\n",
    "        rewards.append(reward)\n",
    "        actions.append(0)\n",
    "        tracked.append(carry[\"tracked_obj_ids\"][config.layer_for_dynamics])\n",
    "        nc.append(carry[\"rmm_model\"].used_mask.sum())\n",
    "\n",
    "        xs.append(carry[\"x\"][config.layer_for_dynamics])\n",
    "        switches.append(rec[\"switches\"])\n",
    "        rmm_switches.append(rec[\"rmm_switches\"])\n",
    "        # no plan was made for the reset frame: log a uniform distribution\n",
    "        probs.append(jnp.ones_like(probs[-1]) / probs[-1].shape[-1])\n",
    "        used.append(carry[\"used\"])\n",
    "        moving.append(carry[\"moving\"])\n",
    "\n",
    "        # keep `identities` the same length as `xs` and `tracked`, which are\n",
    "        # plotted side by side later on\n",
    "        identity_t = imm_tools.infer_identity(carry[\"imm_model\"], xs[-1][..., None], config.imm.color_only_identity)\n",
    "        identities.append(identity_t)\n",
    "\n",
    "    if (t + 1) % config.prune_every == 0:\n",
    "        key, subkey = jr.split(key)\n",
    "        new_rmm, pairs, *bmr_buffer = ax.reduce_fn_rmm(\n",
    "            subkey, carry[\"rmm_model\"], *bmr_buffer\n",
    "        )\n",
    "        # report the rMM before and after the reduction, then keep the reduced one\n",
    "        vis.generate_report(\n",
    "            rewards, None, nc, carry[\"rmm_model\"], new_rmm, carry[\"imm_model\"]\n",
    "        )\n",
    "        carry[\"rmm_model\"] = new_rmm\n",
    "\n",
    "    if (t + 1) % 500 == 0:\n",
    "        if config.prune_every >= config.num_steps:\n",
    "            # the reduction branch above does not report in this case, so report here\n",
    "            vis.generate_report(\n",
    "                rewards,\n",
    "                None,\n",
    "                nc,\n",
    "                carry[\"rmm_model\"],\n",
    "                carry[\"rmm_model\"],\n",
    "                carry[\"imm_model\"],\n",
    "            )\n",
    "\n",
    "        if jnp.sum(jnp.asarray(rewards[-500:-50]) == -1) > 0:\n",
    "            # 20 steps before the most recent negative reward in the last 500 steps\n",
    "            idx = jnp.argwhere(jnp.asarray(rewards[-500:]) == -1).flatten()[-1] - 20\n",
    "            idx = len(rewards) - 500 + idx\n",
    "\n",
    "            best = jnp.argsort(plan_info[\"rewards\"][:, :, 0].sum(0))[-1]\n",
    "            # number of planned timesteps, read from the plan instead of hard-coded\n",
    "            horizon = plan_info[\"states\"].shape[0]\n",
    "            mediapy.show_videos(\n",
    "                {\n",
    "                    \"last_obs\": observations[-500:],\n",
    "                    \"fail\": observations[idx : idx + 20],\n",
    "                    \"plan\": [\n",
    "                        vis.plot_obs_and_info(\n",
    "                            None, plan_info[\"states\"][plan_step, best, 0]\n",
    "                        )\n",
    "                        for plan_step in range(horizon)\n",
    "                    ],\n",
    "                }\n",
    "            )\n",
    "            try:\n",
    "                print(\"Investigating index\", idx)\n",
    "                investigate_plan(carry, xs, tracked, observations, actions, idx, 1)\n",
    "            except Exception as err:\n",
    "                # Plotting is best-effort, but report why it failed and let\n",
    "                # KeyboardInterrupt stop the run.\n",
    "                print(\"failed to plot stuff\", repr(err))\n",
    "        else:\n",
    "            mediapy.show_videos({\"last_obs\": observations[-500:]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "Visualize gameplay of the last 1000 frames."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the most recent logged frames (at most 1000) as a GIF, labelled with the game name.\n",
    "mediapy.show_videos({game: observations[-1000:]}, fps=40, codec=\"gif\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visualize the SMM output of the final frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smm_model = carry[\"smm_model\"]\n",
    "rmm_model = carry[\"rmm_model\"]\n",
    "\n",
    "stats = smm_model.stats\n",
    "width, height = smm_model.width, smm_model.height\n",
    "\n",
    "# `rec` is the record returned by the most recent ax.step_fn call;\n",
    "# every panel looks at the layer used for dynamics.\n",
    "smm_panels = {\n",
    "    \"qx\": vis.plot_qx_smm(\n",
    "        rec[\"decoded_mu\"][config.layer_for_dynamics],\n",
    "        rec[\"decoded_sigma\"][config.layer_for_dynamics],\n",
    "        stats[\"offset\"],\n",
    "        stats[\"stdevs\"],\n",
    "        width,\n",
    "        height,\n",
    "        rec[\"qz\"][config.layer_for_dynamics],\n",
    "    ),\n",
    "    \"qz\": vis.plot_qz_smm(rec[\"qz\"][config.layer_for_dynamics], width, height),\n",
    "    \"smm_eloglike\": vis.plot_elbo_smm(\n",
    "        rec[\"smm_eloglike\"][config.layer_for_dynamics], width, height\n",
    "    ),\n",
    "}\n",
    "mediapy.show_images(smm_panels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visualize the discovered \"identities\" by the model. You can also get a sense of which slots a particular object identity occupies during the experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the identity model stored in the agent state, then render the figure.\n",
    "vis.plot_identity_model(carry[\"imm_model\"], return_ax=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check which slots are used and tracked over time, as well as the inferred identities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
    "\n",
    "# (image, title, extra imshow options) for each panel, left to right\n",
    "timeseries_panels = [\n",
    "    (jnp.stack(tracked), \"Tracked slot timeseries\", {\"cmap\": \"gray\"}),\n",
    "    (\n",
    "        jnp.stack(xs)[:, :, 2] == 0,\n",
    "        \"Used (and hence visible) timeseries\",\n",
    "        {\"cmap\": \"gray\"},\n",
    "    ),\n",
    "    (jnp.stack(identities), \"Inferred identity timeseries\", {}),\n",
    "]\n",
    "for panel_ax, (image, title, extra) in zip(axes, timeseries_panels):\n",
    "    panel_ax.imshow(image, aspect=\"auto\", interpolation=\"none\", **extra)\n",
    "    panel_ax.set_title(title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect particular RMM clusters, filtered based on some discrete/continuous observations. Add or modify the filters in the select clause."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _dominant_category(likelihood_index):\n",
    "    \"\"\"Per-cluster argmax of alpha[:, :, 0] for the given discrete likelihood.\"\"\"\n",
    "    likelihood = carry[\"rmm_model\"].model.discrete_likelihoods[likelihood_index]\n",
    "    return likelihood.alpha[:, :, 0].argmax(-1)\n",
    "\n",
    "\n",
    "# Only show used clusters\n",
    "select = carry[\"rmm_model\"].used_mask > 0\n",
    "\n",
    "# Only show clusters with identity 2\n",
    "select = select & (_dominant_category(0) == 2)\n",
    "\n",
    "# Only show clusters where it's interacting with object identity 0\n",
    "select = select & (_dominant_category(1) == 0)\n",
    "\n",
    "# Only show clusters where the object disappears (i.e. has SLDS switch 2)\n",
    "select = select & (_dominant_category(-1) == 2)\n",
    "\n",
    "\n",
    "selected_clusters = jnp.where(select)[0]\n",
    "vis.plot_rmm(\n",
    "    carry[\"rmm_model\"],\n",
    "    carry[\"imm_model\"],\n",
    "    width=20,\n",
    "    height=20,\n",
    "    colorize=\"cluster\",\n",
    "    indices=selected_clusters,\n",
    "    return_ax=True,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use `vis.plot_rmm_detail` to look into the details of a particular cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first cluster that passed the filters above; the [0] assumes that\n",
    "# at least one cluster is selected.\n",
    "vis.plot_rmm_detail(carry[\"rmm_model\"].model, jnp.argwhere(select).flatten()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspecting particular planner rollouts\n",
    "\n",
    "To further debug, let's look at a failure case and inspect what the planner predicts to do. First find an index shortly before the first failure in the selected time window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_t, end_t = 0, 10000\n",
    "num_steps_before = 20\n",
    "reward_type = -1  # change to 1 to time-lock to a reward\n",
    "\n",
    "# Positions (relative to start_t) inside the window where the reward occurred.\n",
    "hits = jnp.argwhere(jnp.asarray(rewards[start_t:end_t]) == reward_type).flatten()\n",
    "if hits.size == 0:\n",
    "    raise ValueError(\n",
    "        f\"No reward equal to {reward_type} found for t in [{start_t}, {end_t}); \"\n",
    "        \"adjust the window or reward_type.\"\n",
    "    )\n",
    "\n",
    "# Time of the first matching reward in the window.\n",
    "event_t = int(hits[0]) + start_t\n",
    "# Clamp at 0: a negative start would silently index from the end of the logs.\n",
    "idx = max(event_t - num_steps_before, 0)\n",
    "\n",
    "print(f\"Reward: {reward_type} found at t={event_t}\")\n",
    "mediapy.show_videos(\n",
    "    {\"reward\": observations[idx : idx + num_steps_before]}, codec=\"gif\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given some time-offset t from the starting idx, inspect the plan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Offset from `idx` at which to re-plan; note this overwrites the training loop's `t`.\n",
    "t = 0\n",
    "plan_info = investigate_plan(carry, xs, tracked, observations, actions, idx, t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also visualize one particular plan index, e.g. just the best one found."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reward summed over the planned timesteps; the last argsort entry is the largest.\n",
    "policy_returns = plan_info[\"rewards\"][:, :, 0].sum(0)\n",
    "best = jnp.argsort(policy_returns)[-1]\n",
    "\n",
    "best_plan_image = vis.plot_plan(\n",
    "    observations[idx + t],\n",
    "    plan_info,\n",
    "    tracked[idx + t],\n",
    "    carry[\"smm_model\"].stats,\n",
    "    indices=jnp.array([best]),\n",
    ")\n",
    "mediapy.show_image(best_plan_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, for each tracked object, we can print the planned tMM switches (i.e., for the first sample of the best planned policy)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per object tracked at step idx + t: its switches along the best plan (sample 0).\n",
    "for object_idx in jnp.argwhere(tracked[idx+t]).flatten():\n",
    "    print(object_idx, plan_info[\"switches\"][:, best, 0, object_idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can do the same for the inferred rMM clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per object tracked at step idx + t: its rMM switches along the best plan (sample 0).\n",
    "for object_idx in jnp.argwhere(tracked[idx + t]).flatten():\n",
    "    print(object_idx, plan_info[\"rmm_switches\"][:, best, 0, object_idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Moreover, for a particular rollout, timestep and object_idx you can inspect the inferred rMM cluster in more detail, by calling the predict method directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import asdict\n",
    "\n",
    "from axiom.models.rmm import _to_distance_obs_hybrid\n",
    "from axiom.models.rmm import predict\n",
    "\n",
    "# Which imagined state to inspect: rollout, sample, object slot and plan timestep.\n",
    "policy_idx = best\n",
    "sample_idx = 0\n",
    "object_idx = 1\n",
    "timestep = 1\n",
    "\n",
    "rmm = carry[\"rmm_model\"]\n",
    "imm = carry[\"imm_model\"]\n",
    "# takes the predicted state from the plan at timestep as input\n",
    "x_t = plan_info[\"states\"][timestep, policy_idx, sample_idx]\n",
    "# to predict the next state, we need the previous action in the plan\n",
    "action_t = plan_info[\"actions\"][timestep - 1, policy_idx]\n",
    "\n",
    "# The plan was computed from step idx + t, so use the tracked mask of that step.\n",
    "tracked_obj_ids = tracked[idx + t]\n",
    "\n",
    "c_obs, d_obs = _to_distance_obs_hybrid(\n",
    "    imm,\n",
    "    x_t,\n",
    "    object_idx,\n",
    "    action_t,\n",
    "    tmm_switch=10,  # we pass in a dummy value (gets overwritten in predict)\n",
    "    reward=jnp.array(0),  # we pass in a dummy value (gets overwritten in predict)\n",
    "    tracked_obj_mask=tracked_obj_ids,\n",
    "    max_switches=config.tmm.n_total_components,\n",
    "    action_dim=rmm.model.discrete_likelihoods[-3].alpha.shape[-2],\n",
    "    object_identities=None,\n",
    "    num_object_classes=rmm.model.discrete_likelihoods[0].alpha.shape[1] - 1,\n",
    "    **asdict(config.rmm),\n",
    ")\n",
    "# Add leading and trailing singleton axes before calling predict.\n",
    "c_obs = c_obs[None, :, None]\n",
    "d_obs = jtu.tree_map(lambda d: d[None, :, None], d_obs)\n",
    "\n",
    "# Compute the tMM switching slot using the rMM\n",
    "switch_slot, pred_reward, ell, qz, r_cluster = predict(\n",
    "    rmm,\n",
    "    c_obs,\n",
    "    d_obs,\n",
    "    key=None,\n",
    "    reward_prob_threshold=config.rmm.reward_prob_threshold,\n",
    ")\n",
    "\n",
    "# Imagined state next to the rMM cluster with the highest qz.\n",
    "mediapy.show_images(\n",
    "    {\n",
    "        \"imagined\": vis.plot_obs_and_info(None, x_t),\n",
    "        \"rmm_cluster\": vis.plot_rmm(\n",
    "            rmm, imm, indices=jnp.argsort(qz)[-1:], colorize=\"cluster\"\n",
    "        ),\n",
    "    },\n",
    "    width=300,\n",
    ")\n",
    "\n",
    "top5_qz = jnp.argsort(qz)[-5:]\n",
    "print(\"Top 5 qzs\")\n",
    "print(top5_qz)\n",
    "print(qz[top5_qz])\n",
    "\n",
    "# Detail plots only for clusters with qz above 0.1.\n",
    "for i in top5_qz:\n",
    "    if qz[i] > 0.1:\n",
    "        vis.plot_rmm_detail(\n",
    "            rmm.model,\n",
    "            i,\n",
    "            c_obs=c_obs[0],\n",
    "            d_obs=jtu.tree_map(lambda d: d[0], d_obs),\n",
    "        )\n",
    "\n",
    "print(\"tMM component\")\n",
    "print(carry[\"tmm_model\"].transitions[switch_slot.item()])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
