{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MuZero Visualization Analysis\n",
    "This code provides a visualization analysis of the MountainCar environment's observation space and the latent states of a trained MuZero model using PCA (Principal Component Analysis), following the paper [Visualizing MuZero Models](http://arxiv.org/abs/2102.12924). The goal is to gain insights into the specific types of representations learned by these models.\n",
    "\n",
    "The code performs the following steps:\n",
    "\n",
    "1. Import the required libraries and modules.\n",
    "2. Load the MountainCar environment and its observation space.\n",
    "3. Load a pre-trained MuZero model.\n",
    "4. Generate random observation samples (observations) from the MountainCar environment.\n",
    "5. Perform PCA to reduce the dimensionality of the original observation space.\n",
    "6. Extract the latent states from the MuZero model by inputting the observation samples into the model's representation network and extracting the output.\n",
    "7. Apply PCA to reduce the dimensionality of the latent states.\n",
    "\n",
    "8. Visualize the reduced observation space and latent states.\n",
    "By conducting these visualization analyses, we can gain a deeper understanding of how the MuZero model learns and represents information in the MountainCar environment.\n",
    "\n",
    "For more information about the mountain_car environment, see [mountain_car doc](https://www.gymlibrary.dev/environments/classic_control/mountain_car/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from functools import partial\n",
    "from typing import List, Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from tensorboardX import SummaryWriter\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "from ding.config import compile_config\n",
    "from ding.envs import create_env_manager\n",
    "from ding.envs import get_vec_env_setting\n",
    "from ding.policy import create_policy\n",
    "from ding.utils import set_pkg_seed\n",
    "from ding.torch_utils import to_tensor, to_device, to_ndarray\n",
    "from ding.worker import BaseLearner\n",
    "from lzero.worker import MuZeroEvaluator\n",
    "from lzero.policy import InverseScalarTransform, mz_network_output_unpack\n",
    "\n",
    "from zoo.classic_control.mountain_car.config.mtcar_muzero_config import main_config, create_config\n",
    "# from lzero.entry import eval_muzero"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model\n",
    "This code segment loads a pre-trained MuZero model and performs evaluation on the MountainCar environment.\n",
    "\n",
    "It sets up the necessary configurations, components, and dependencies for the evaluation process.\n",
    "The MuZero model is loaded from the specified model path and the evaluation is performed using the MuZeroEvaluator.\n",
    "The evaluation results, including trajectories and returns, are stored for further analysis and visualization.\n",
    "\n",
    "This code provides a convenient way to evaluate the performance of a trained MuZero model in the MountainCar environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder path: replace 'your_path' with the actual experiment directory of the checkpoint.\n",
    "model_path = \"your_path/mountain_car_muzero_seed0/ckpt/ckpt_best.pth.tar\"\n",
    "returns_mean_seeds = []\n",
    "returns_seeds = []\n",
    "seed = 0\n",
    "num_episodes_each_seed = 1\n",
    "total_test_episodes = num_episodes_each_seed\n",
    "create_config.env_manager.type = 'base'  # Visualization requires the 'type' to be set as base\n",
    "main_config.env.evaluator_env_num = 1  # Visualization requires the 'env_num' to be set as 1\n",
    "main_config.env.n_evaluator_episode = total_test_episodes\n",
    "main_config.env.replay_path = 'lz_result/video/mtcar_mz'\n",
    "main_config.exp_name = f'lz_result/eval/muzero_eval_ls{main_config.policy.model.latent_state_dim}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg, create_cfg = main_config, create_config\n",
    "assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \\\n",
    "    \"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'\"\n",
    "\n",
    "# Select the inference device based on the config and CUDA availability.\n",
    "if cfg.policy.cuda and torch.cuda.is_available():\n",
    "    cfg.policy.device = 'cuda'\n",
    "else:\n",
    "    cfg.policy.device = 'cpu'\n",
    "\n",
    "cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)\n",
    "# Create main components: env, policy\n",
    "env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)\n",
    "evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])\n",
    "\n",
    "# Seed the environment and all packages so the evaluation is reproducible.\n",
    "evaluator_env.seed(cfg.seed, dynamic_seed=False)\n",
    "set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)\n",
    "\n",
    "policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval'])\n",
    "\n",
    "# load pretrained model\n",
    "if model_path is not None:\n",
    "    policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))\n",
    "\n",
    "# Create worker components: learner, collector, evaluator, replay buffer, commander.\n",
    "tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))\n",
    "learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)\n",
    "\n",
    "# ==============================================================\n",
    "# MCTS+RL algorithms related core code\n",
    "# ==============================================================\n",
    "policy_config = cfg.policy\n",
    "evaluator = MuZeroEvaluator(\n",
    "    eval_freq=cfg.policy.eval_freq,\n",
    "    n_evaluator_episode=cfg.env.n_evaluator_episode,\n",
    "    stop_value=cfg.env.stop_value,\n",
    "    env=evaluator_env,\n",
    "    policy=policy.eval_mode,\n",
    "    tb_logger=tb_logger,\n",
    "    exp_name=cfg.exp_name,\n",
    "    policy_config=policy_config\n",
    ")\n",
    "\n",
    "# ==========\n",
    "# Main loop\n",
    "# ==========\n",
    "# Learner's before_run hook.\n",
    "learner.call_hook('before_run')\n",
    "\n",
    "# ==============================================================\n",
    "# eval trained model\n",
    "# ==============================================================\n",
    "# return_trajectory=True asks the evaluator to also return the collected game segments.\n",
    "stop_flag, episode_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, return_trajectory=True)\n",
    "trajectorys = episode_info['trajectory']\n",
    "returns = episode_info['eval_episode_return']\n",
    "returns = np.array(returns)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Original space enumeration\n",
    "\n",
    "This code segment provides functions to work with the original state space of an environment.\n",
    "\n",
    "The `create_grid` function creates a grid in the original state space based on the specified resolution.\n",
    "\n",
    "The `get_state_space` function generates the state space grid using the observation space of the environment.\n",
    "\n",
    "The `embedding_manifold` function computes the latent states, values, and policy logits for the given state space using a trained model.\n",
    "\n",
    "The code then applies these functions to the MountainCar environment, printing the shapes of the original state space and the corresponding latent state space.\n",
    "This code facilitates the exploration and analysis of the original state space and its embeddings using a trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_grid(v_mins: List, v_maxs: List, resolution: int) -> np.ndarray:\n",
    "    \"\"\"Build a regular grid over the box [v_mins, v_maxs] with `resolution` points per dimension.\n",
    "\n",
    "    Returns an array of shape (resolution ** n_dims, n_dims), one row per grid point.\n",
    "    \"\"\"\n",
    "    data = list(map(lambda r: np.linspace(*r, resolution), zip(v_mins, v_maxs)))\n",
    "    grid = np.asarray(np.meshgrid(*data, indexing=\"ij\")).T.reshape(-1, len(v_mins))\n",
    "    return grid\n",
    "\n",
    "def get_state_space(env, resolution: int = 25) -> np.ndarray:\n",
    "    \"\"\"Enumerate the env's box observation space as a regular grid of states.\"\"\"\n",
    "    obs_space = env.observation_space\n",
    "    state_space = create_grid(obs_space.low, obs_space.high, resolution)\n",
    "    return state_space\n",
    "\n",
    "\n",
    "def embedding_manifold(state_space, model, return_pis: bool = False, policy_cfg = None) -> Tuple:\n",
    "    \"\"\"Embed a batch of states with the model's representation network.\n",
    "\n",
    "    Runs initial_inference on `state_space` and maps the value head output back\n",
    "    to a scalar with InverseScalarTransform.\n",
    "\n",
    "    Returns:\n",
    "        - (latent_states, values) as numpy arrays; policy logits are appended when return_pis is True.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        network_output = model.initial_inference(state_space)\n",
    "    latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output)\n",
    "    inverse_scalar_transform_handler = InverseScalarTransform(\n",
    "        policy_cfg.model.support_scale,\n",
    "        policy_cfg.device,\n",
    "        policy_cfg.model.categorical_distribution)\n",
    "    value_real = inverse_scalar_transform_handler(value)\n",
    "\n",
    "    if return_pis:\n",
    "        return to_ndarray(latent_state.cpu()), to_ndarray(value_real.cpu()), to_ndarray(policy_logits.cpu())\n",
    "    \n",
    "    return to_ndarray(latent_state.cpu()), to_ndarray(value_real.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid resolution per observation dimension; the enumerated state space has delta**2 points.\n",
    "delta = 250\n",
    "state_space = get_state_space(evaluator_env, delta)\n",
    "state_space_tensor = to_device(to_tensor(state_space), policy_config.device)\n",
    "latent_state_space, v_state_space = embedding_manifold(state_space_tensor, policy._model, policy_cfg=policy_config)\n",
    "print(state_space.shape, latent_state_space.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PCA\n",
    "\n",
    "This code segment provides functions related to Principal Component Analysis (PCA) for latent states.\n",
    "\n",
    "The `embedding_PCA` function performs PCA on the given latent states array.\n",
    "It accepts an optional `standardize` parameter to control whether the data should be standardized before performing PCA.\n",
    "The function computes the principal components and the explained variance ratio.\n",
    "It then creates a bar chart to visualize the explained variance ratio for each principal component.\n",
    "Additionally, it creates a violin plot to show the distribution of the projected values.\n",
    "The function returns the PCA object.\n",
    "\n",
    "The code applies the `embedding_PCA` function to the `latent_state_space` array twice, once with standardization disabled and once with it enabled.\n",
    "Finally, it transforms the `latent_state_space` using the computed PCA for further analysis or visualization.\n",
    "\n",
    "This code provides a convenient way to perform PCA on latent states and visualize the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embedding_PCA(latent_states: np.ndarray, standardize: bool = False):\n",
    "    \"\"\"Run PCA on the latent states and plot diagnostics.\n",
    "\n",
    "    Shows a bar chart of the explained-variance ratio per principal component\n",
    "    and a violin plot of the projected value distributions.\n",
    "\n",
    "    Arguments:\n",
    "        - latent_states: array of shape (n_samples, latent_dim).\n",
    "        - standardize: if True, z-score each latent dimension before PCA.\n",
    "    Returns:\n",
    "        - The fitted sklearn PCA object.\n",
    "    \"\"\"\n",
    "    x = latent_states\n",
    "    if standardize:\n",
    "        x = (x - x.mean(axis=0)) / x.std(axis=0)\n",
    "\n",
    "    # Perform PCA on latent dimensions.\n",
    "    # fit_transform both fits and projects in one pass; the original code redundantly called fit() first.\n",
    "    pca = PCA(n_components=x.shape[-1])\n",
    "    spcs = pca.fit_transform(x)\n",
    "\n",
    "    # Bar chart of the explained-variance ratio per principal component.\n",
    "    ns = list(range(x.shape[-1]))\n",
    "    var = pca.explained_variance_ratio_\n",
    "\n",
    "    plt.bar(ns, var)\n",
    "\n",
    "    plt.title(f\"PCA on latent-states (standardize={standardize})\")\n",
    "    plt.ylabel(\"Explained Variance Ratio\")\n",
    "    plt.xlabel(\"Principal Component\")\n",
    "\n",
    "    # Annotate each bar with its explained-variance ratio.\n",
    "    for i in range(len(var)):\n",
    "        plt.annotate(f'{var[i]:.3f}', xy=(ns[i], var[i]), ha='center', va='bottom')\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "    # Violin plot of the projected (PC-space) value distributions.\n",
    "    plt.violinplot(spcs)\n",
    "    plt.xticks(range(1, x.shape[-1] + 1), range(1, x.shape[-1] + 1))\n",
    "\n",
    "    plt.title(f\"Projected values distribution (standardize={standardize})\")\n",
    "    plt.ylabel(\"PC values\")\n",
    "    plt.xlabel(\"Principal Component\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "    return pca"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit PCA twice for comparison: on raw latent values and on standardized values.\n",
    "pca = embedding_PCA(latent_state_space, False)\n",
    "pca_norm = embedding_PCA(latent_state_space, True)\n",
    "# Project the enumerated latent states with the non-standardized PCA for later plots.\n",
    "pca_latent_state_space = pca.transform(latent_state_space)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Original space/PCA latent state visualization\n",
    "\n",
    "This code segment provides functions for visualizing the state space.\n",
    "\n",
    "The `to_grid` function reshapes the input array into a grid with a specified resolution.\n",
    "\n",
    "The `simple_PC_value_contour` function creates a scatter plot to visualize the PCA-transformed latent states.\n",
    "It takes the first two PCA components (`pc_1` and `pc_2`) along with the corresponding values (`z`) as inputs.\n",
    "The scatter points are colored based on the values (`z`) and a colorbar is added to indicate the value scale.\n",
    "\n",
    "The `simple_MC_value_contour` function creates a contour plot to visualize the original state space of the MountainCar environment.\n",
    "It takes the position values (`x`), velocity values (`y`), and corresponding values (`z`) as inputs.\n",
    "The contour levels represent the values (`z`), and a colorbar is added to show the value scale.\n",
    "\n",
    "The code applies these visualization functions to the PCA-transformed latent state space (`pca_latent_state_space`) and the original state space (`state_space`).\n",
    "The resulting plots provide visual representations of the value distribution in the latent space and the original state space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# State space visualization\n",
    "def to_grid(x: np.ndarray, delta: int) -> np.ndarray:\n",
    "    \"\"\"Reshape a flat array of delta * delta values into a (delta, delta) grid.\"\"\"\n",
    "    return x.reshape(delta, delta)\n",
    "\n",
    "\n",
    "def simple_PC_value_contour(pc_1: np.ndarray, pc_2: np.ndarray, z: np.ndarray) -> None:\n",
    "    \"\"\"Scatter the latent states in the plane of the first two PCs, colored by value.\n",
    "\n",
    "    Arguments:\n",
    "        - pc_1: first principal component (plotted on the x-axis).\n",
    "        - pc_2: second principal component (plotted on the y-axis).\n",
    "        - z: value estimate per point, used as the color scale.\n",
    "    \"\"\"\n",
    "    plt.scatter(pc_1, pc_2, c=z, alpha=0.5, s=5, cmap='rainbow')\n",
    "\n",
    "    cbar = plt.colorbar()\n",
    "    cbar.set_label(r'$V_\\theta(o_t)$')\n",
    "\n",
    "    plt.title(\"Value Contour MuZero PC-Space\")\n",
    "    # Label the axes to match the plotted data: pc_1 is on x, pc_2 on y\n",
    "    # (the original labels were swapped).\n",
    "    plt.xlabel(r\"First PCA component $h_\\theta(o_t)$\")\n",
    "    plt.ylabel(r\"Second PCA component $h_\\theta(o_t)$\")\n",
    "\n",
    "\n",
    "def simple_MC_value_contour(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None:\n",
    "    \"\"\"Contour plot of the value estimate over the original MountainCar state space.\n",
    "\n",
    "    Arguments:\n",
    "        - x: position grid of shape (delta, delta).\n",
    "        - y: velocity grid of shape (delta, delta).\n",
    "        - z: value-estimate grid of shape (delta, delta).\n",
    "    \"\"\"\n",
    "    # Simple Example Figure for a 2-d env\n",
    "    plt.title(\"Value Contour MuZero MountainCar\")\n",
    "    plt.ylabel(\"Velocity\")\n",
    "    plt.xlabel(\"Position\")\n",
    "\n",
    "    plt.contourf(x, y, z, levels=100, cmap='rainbow')\n",
    "\n",
    "    cbar = plt.colorbar()\n",
    "    cbar.set_label(r'$V_\\theta(o_t)$')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value contour in the 2-D PC projection of the latent space.\n",
    "simple_PC_value_contour(pca_latent_state_space[:, 0], pca_latent_state_space[:, 1], v_state_space)\n",
    "plt.show()\n",
    "# Value contour over the original (position, velocity) state space.\n",
    "simple_MC_value_contour(to_grid(state_space[:,0], delta), to_grid(state_space[:,1], delta), to_grid(v_state_space, delta))\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the trajectory obtained by evaluation\n",
    "\n",
    "This code segment involves processing game trajectories and obtaining latent state dynamics.\n",
    "\n",
    "The `get_latent_trajectory` function takes embeddings, actions, and a model as inputs and returns the latent state trajectory.\n",
    "It initializes the latent state with the first embedding, and then iteratively computes the latent state using the model's recurrent inference.\n",
    "The resulting latent states are stored in a list and concatenated to form the stacked latent state trajectory.\n",
    "\n",
    "The code then loads a real state trajectory (`real_state`) and corresponding actions (`actions`).\n",
    "These trajectories are converted to tensors and processed using the policy model.\n",
    "The resulting latent state representations (`latent_state_represent`) and value trajectory (`v_trajectorys`) are converted to NumPy arrays.\n",
    "\n",
    "Next, the code calls the `get_latent_trajectory` function to obtain the latent state dynamics based on the latent state representations and actions.\n",
    "The obtained latent state dynamics are then projected to the PC-space using the precomputed PCA object (`pca`).\n",
    "The resulting PC-space trajectories are stored in `pc_embedding_trajectory` and `pc_dynamics_trajectory`.\n",
    "\n",
    "This code provides a way to process game trajectories, extract latent state dynamics, and project them into the PC-space for further analysis or visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Game trajectory dynamics latent state processing\n",
    "def get_latent_trajectory(embeddings: torch.Tensor, actions: torch.Tensor, model) -> np.ndarray:\n",
    "    \"\"\"Unroll the learned dynamics from the first embedding along a sequence of actions.\n",
    "\n",
    "    Starting from embeddings[0] (the representation of the first observation),\n",
    "    repeatedly applies model.recurrent_inference with each action and collects\n",
    "    every intermediate latent state.\n",
    "\n",
    "    Returns:\n",
    "        - np.ndarray stacking len(actions) + 1 latent states.\n",
    "    \"\"\"\n",
    "    latent_state = embeddings[0].unsqueeze(0)\n",
    "    \n",
    "    latent_states = list()\n",
    "    latent_states.append(to_ndarray(latent_state.cpu()))\n",
    "    with torch.no_grad():\n",
    "        for i in range(len(actions)):\n",
    "            \n",
    "            network_output = model.recurrent_inference(latent_state, actions[i].unsqueeze(0))    # note: mind the expected action shape here\n",
    "            latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output)\n",
    "        \n",
    "            # memory = latent_state\n",
    "            latent_states.append(to_ndarray(latent_state.cpu()))\n",
    "\n",
    "    stacked = np.concatenate(latent_states)\n",
    "    return stacked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#   1 Trajectory loading\n",
    "# Take the first evaluated episode: its raw observations and the actions taken.\n",
    "real_state = np.array(trajectorys[0].obs_segment)\n",
    "real_state_tensor = to_device(to_tensor(real_state), state_space_tensor.device)\n",
    "actions = np.array(trajectorys[0].action_segment)\n",
    "actions_tensor = to_device(to_tensor(actions).unsqueeze(1), state_space_tensor.device)\n",
    "# Encode every real observation with the representation network.\n",
    "with torch.no_grad():\n",
    "    network_output = policy._model.initial_inference(real_state_tensor)\n",
    "latent_state_represent_tensor, reward, v_trajectorys_tensor, policy_logits = mz_network_output_unpack(network_output) \n",
    "latent_state_represent = to_ndarray(latent_state_represent_tensor.cpu())\n",
    "v_trajectorys = to_ndarray(v_trajectorys_tensor.cpu())\n",
    "\n",
    "#   2 Get latent state trajectory\n",
    "# Unroll the dynamics network from the first embedding along the real action sequence.\n",
    "latent_state_dynamics = get_latent_trajectory(latent_state_represent_tensor, actions_tensor, policy._model)\n",
    "\n",
    "#   3 Project to PC-space\n",
    "pc_embedding_trajectory = pca.transform(latent_state_represent.reshape(len(latent_state_represent), -1))\n",
    "pc_dynamics_trajectory = pca.transform(latent_state_dynamics.reshape(len(latent_state_dynamics), -1))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Observation distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#   1 latent state trajectory distribution\n",
    "plt.violinplot(latent_state_represent.reshape(len(latent_state_represent), -1), np.arange(1, latent_state_space.shape[-1] + 1))\n",
    "plt.violinplot(latent_state_dynamics.reshape(len(latent_state_dynamics), -1), np.arange(1, latent_state_space.shape[-1] + 1))\n",
    "\n",
    "# Empty scatters exist only to create legend entries for the two violin plots.\n",
    "plt.scatter([], [], label='embedding')\n",
    "plt.scatter([], [], label='dynamics')\n",
    "\n",
    "plt.title(\"Value Distributions within latent-space\")\n",
    "\n",
    "plt.ylabel(\"Values\")\n",
    "plt.xlabel(\"Latent Dimension\")\n",
    "plt.xticks(range(1, latent_state_space.shape[-1] + 1), [f'dim {i}' for i in range(1, latent_state_space.shape[-1] + 1)], rotation=45)\n",
    "\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#   2 Latent state trajectory distribution after PCA\n",
    "plt.violinplot(pc_embedding_trajectory, np.arange(1, latent_state_space.shape[-1] + 1))\n",
    "plt.violinplot(pc_dynamics_trajectory, np.arange(1, latent_state_space.shape[-1] + 1))\n",
    "\n",
    "# Empty scatters exist only to create legend entries for the two violin plots.\n",
    "plt.scatter([], [], label='embedding')\n",
    "plt.scatter([], [], label='dynamics')\n",
    "\n",
    "plt.title(\"Value Distributions within latent PC-space\")\n",
    "\n",
    "plt.ylabel(\"Values\")\n",
    "plt.xlabel(\"Latent Dimension\")\n",
    "plt.xticks(range(1, latent_state_space.shape[-1] + 1), [f'dim {i}' for i in range(1, latent_state_space.shape[-1] + 1)], rotation=45)\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3D trajectory visualization\n",
    "\n",
    "This code segment provides functions for generating 3D visualizations using Plotly.\n",
    "\n",
    "The `generate_3d_surface` function creates a 3D surface plot.\n",
    "\n",
    "The `generate_3d_trajectory` function creates a 3D scatter plot for a trajectory.\n",
    "\n",
    "The `generate_3d_valuefield` function creates a 3D scatter plot for a value field.\n",
    "\n",
    "The code uses these functions to generate 3D visualizations.\n",
    "It creates a 3D trajectory plot (`dynamics_trajectory`) and an embedding trajectory plot (`embedding_trajectory`) using the PC-space trajectories.\n",
    "It also generates a 3D surface plot (`surface`) using the PCA-transformed latent state space and the corresponding values.\n",
    "\n",
    "This code provides a way to visualize 3D trajectories, embedding trajectories, and value fields using Plotly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_3d_surface(x: np.ndarray, y: np.ndarray, z: np.ndarray, colors: np.ndarray, clim=None):\n",
    "    \"\"\"Create a Plotly 3-D surface colored by `colors`; color range is clipped to clim=(min, max) if given.\"\"\"\n",
    "    return go.Surface(\n",
    "        x=x, y=y, z=z,\n",
    "        opacity=1, \n",
    "        surfacecolor=colors,\n",
    "        colorscale='Viridis',\n",
    "        cmin=colors.min() if clim is None else clim[0],\n",
    "        cmax=colors.max() if clim is None else clim[1],\n",
    "        colorbar=dict(title=dict(text='V',side='top'), thickness=50, tickmode='array')\n",
    "    )\n",
    "\n",
    "def generate_3d_trajectory(x: np.ndarray, y: np.ndarray, z: np.ndarray, color: str):\n",
    "    \"\"\"Create a single-color 3-D line+marker trace for a trajectory.\n",
    "\n",
    "    A small random offset is added so overlapping trajectories stay visually distinguishable.\n",
    "    \"\"\"\n",
    "    return go.Scatter3d(\n",
    "        x=x + np.random.rand()*0.01,\n",
    "        y=y + np.random.rand()*0.01,\n",
    "        z=z + np.random.rand()*0.01,\n",
    "        mode='lines+markers',\n",
    "        marker=dict(\n",
    "            size=3,\n",
    "            symbol='x',\n",
    "            color=color,\n",
    "            opacity=1\n",
    "        ),\n",
    "        line=dict(\n",
    "            color=color,\n",
    "            width=20\n",
    "        )\n",
    "    )\n",
    "\n",
    "def generate_3d_valuefield(x: np.ndarray, y: np.ndarray, z: np.ndarray, colors: np.ndarray, clim=None):\n",
    "    \"\"\"Create a 3-D scatter trace of points colored by `colors`; color range is clipped to clim=(min, max) if given.\"\"\"\n",
    "    return go.Scatter3d(\n",
    "        x=x, y=y, z=z,\n",
    "        mode='markers',\n",
    "        marker=dict(\n",
    "            size=4,\n",
    "            color=colors,\n",
    "            colorscale='Viridis',\n",
    "            cmin=colors.min() if clim is None else clim[0],\n",
    "            cmax=colors.max() if clim is None else clim[1],\n",
    "            opacity=1,\n",
    "            colorbar=dict(title=dict(text='V',side='top'), thickness=50, tickmode='array')\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3D trajectory visualization\n",
    "# Trajectory produced by unrolling the dynamics network (recurrent_inference).\n",
    "dynamics_trajectory = generate_3d_trajectory(\n",
    "    pc_dynamics_trajectory[:, 0].ravel(),\n",
    "    pc_dynamics_trajectory[:, 1].ravel(),\n",
    "    pc_dynamics_trajectory[:, 2].ravel(), 'grey')\n",
    "\n",
    "# Trajectory produced by encoding the real observations (initial_inference).\n",
    "embedding_trajectory = generate_3d_trajectory(\n",
    "    pc_embedding_trajectory[:, 0].ravel(),\n",
    "    pc_embedding_trajectory[:, 1].ravel(),\n",
    "    pc_embedding_trajectory[:, 2].ravel(), 'black')\n",
    "\n",
    "# Value field over the enumerated state space, projected onto the first three PCs.\n",
    "surface = generate_3d_valuefield(pca_latent_state_space[:, 0], pca_latent_state_space[:, 1], pca_latent_state_space[:, 2], v_state_space)\n",
    "\n",
    "fig = go.Figure(data=[embedding_trajectory, dynamics_trajectory, surface])\n",
    "\n",
    "# tight layout\n",
    "fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))\n",
    "\n",
    "fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DI-LZ",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
