from util.logger import logger

from typing import List

from omegaconf import DictConfig

import time

import concurrent.futures as cf

import shlex
import subprocess

import gc

from pathlib import Path

from tqdm.auto import tqdm

import numpy as np

import torch

import random

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value
)
from util.image_util import save_pil_as_png
from util.yaml_util import (
    load_yaml, 
    save_yaml, 
    convert_numpy_type_to_native_type
)
from util.torch_util import get_latent
from util.pipeline_util import (
    load_pipeline, load_scheduler, 
    get_inference_step_minus_one, 
    get_pipeline_category_and_type, 
    img_latent_to_pil, 
    get_folder_name
)

from prompt_manager.util import get_prompt_manager

from task.util.seed_list_util import prepare_seed_list

from my_diffusers.scheduling_ddim import register_scheduling_ddim

from OT_MCTS.src.lru_cache import LRUCache

from diffusion_OT_MCTS.mdp.state_space import StateSpace
from diffusion_OT_MCTS.mdp.action_space_eta import ActionSpaceEta
from diffusion_OT_MCTS.mdp.action_space_eps import ActionSpaceEps
from diffusion_OT_MCTS.mdp.diffusion_mdp import DiffusionMDP
from diffusion_OT_MCTS.reward_model.util import get_reward_model
from diffusion_OT_MCTS.diffusion_ot_mcts import DiffusionOTMCTS


@torch.no_grad()
def run_optimal_control_mcts_implement(
    cfg: DictConfig
):
    cfg_dict = {}

    # ---------= [Basic Global Variables] =---------
    exp_name = get_global_variable("exp_name")
    start_time = get_global_variable("start_time")
    device = get_global_variable("device")
    seed = get_global_variable("seed")
    exp_time_str = f"{exp_name}_{start_time}"

    concurrent_max_worker = get_global_variable("concurrent_max_worker")

    vae_decode_batch_size = get_global_variable("vae_decode_batch_size")

    cfg_dict["global_variable"] = {
        "exp_name": exp_name, 
        "start_time": start_time, 
        "device": device, 
        "seed": seed, 
        "exp_time_str": exp_time_str, 
        "concurrent_max_worker": concurrent_max_worker, 
        "vae_decode_batch_size": vae_decode_batch_size
    }

    # ---------= [Pipeline] =---------
    logger(f"[Pipeline] Loading started. ")

    pipeline_type = get_true_value(cfg["pipeline"]["pipeline_type"])
    pipeline_path = get_true_value(cfg["pipeline"]["pipeline_path"])
    pipeline_torch_dtype = get_true_value(cfg["pipeline"]["torch_dtype"])
    pipeline_variant = get_true_value(cfg["pipeline"]["variant"])
    
    logger(f"    pipeline_type: {pipeline_type}")
    logger(f"    pipeline_path: {pipeline_path}")
    logger(f"    pipeline_torch_dtype: {pipeline_torch_dtype}")
    logger(f"    pipeline_variant: {pipeline_variant}")

    scheduler_type = get_true_value(cfg["pipeline"]["scheduler"])

    logger(f"    scheduler_type: {scheduler_type}")

    cfg_dict["pipeline"] = {
        "pipeline_type": pipeline_type, 
        "pipeline_path": pipeline_path, 
        "pipeline_torch_dtype": pipeline_torch_dtype, 
        "pipeline_variant": pipeline_variant, 

        "scheduler_type": scheduler_type
    },

    logger(
        f"[Pipeline] Loading finished. "
        "\n"
    )

    # ---------= [Eps Action] =---------
    task_name = get_true_value(cfg["task"]["name"])

    is_eps_action = True if task_name.startswith("search-run_optimal_control_mcts_eps") \
        else False

    # ---------= [Prompt List] =---------
    logger(f"[Prompt List] Loading started. ")

    num_prompt = get_true_value(cfg["task"]["prompt_list"]["num_prompt"])
    prompt_manager_dict = get_true_value(cfg["task"]["prompt_list"]["prompt_manager_dict"])

    logger(f"    num_prompt: {num_prompt}")
    logger(f"    prompt_manager_dict: {prompt_manager_dict}")

    logger(
        f"[Prompt List] Loading finished. "
        "\n"
    )

    # ---------= [Init Latent] =---------
    logger(f"[Init Latent] Loading started. ")

    init_latent_random = get_true_value(cfg["task"]["init_latent"]["random"])

    logger(f"    init_latent_random: {init_latent_random}")

    cfg_dict["init_latent"] = {
        "init_latent_random": init_latent_random
    }

    if init_latent_random:
        init_latent_seed_st = get_true_value(cfg["task"]["init_latent"]["seed_st"])
        init_latent_seed_ed = get_true_value(cfg["task"]["init_latent"]["seed_ed"])
        
        logger(f"    init_latent_seed_st: {init_latent_seed_st}")
        logger(f"    init_latent_seed_ed: {init_latent_seed_ed}")

        cfg_dict["init_latent"] = {
            "init_latent_seed_st": init_latent_seed_st, 
            "init_latent_seed_ed": init_latent_seed_ed
        }
    else:
        init_latent_seed_list = get_true_value(cfg["task"]["init_latent"]["seed_list"])
        init_latent_seed_auto_increment = get_true_value(cfg["task"]["init_latent"]["seed_auto_increment"])
        
        logger(f"    init_latent_seed_list: {init_latent_seed_list}")
        logger(f"    init_latent_seed_auto_increment: {init_latent_seed_auto_increment}")

        cfg_dict["init_latent"] = {
            "init_latent_seed_list": init_latent_seed_list, 
            "init_latent_seed_auto_increment": init_latent_seed_auto_increment
        }

    logger(
        f"[Init Latent] Loading finished. "
        "\n"
    )

    # ---------= [Eps] =---------
    logger(f"[Eps] Loading started. ")

    eps_random = get_true_value(cfg["task"]["eps"]["random"])

    logger(f"    eps_random: {eps_random}")

    cfg_dict["eps"] = {
        "eps_random": eps_random
    }

    if eps_random:
        eps_seed_st = get_true_value(cfg["task"]["eps"]["seed_st"])
        eps_seed_ed = get_true_value(cfg["task"]["eps"]["seed_ed"])
        
        logger(f"    eps_seed_st: {eps_seed_st}")
        logger(f"    eps_seed_ed: {eps_seed_ed}")

        cfg_dict["eps"] = {
            "eps_seed_st": eps_seed_st, 
            "eps_seed_ed": eps_seed_ed
        }
    else:
        eps_seed_list = get_true_value(cfg["task"]["eps"]["seed_list"])
        eps_seed_auto_increment = get_true_value(cfg["task"]["eps"]["seed_auto_increment"])
        
        logger(f"    eps_seed_list: {eps_seed_list}")
        logger(f"    eps_seed_auto_increment: {eps_seed_auto_increment}")

        cfg_dict["eps"] = {
            "eps_seed_list": eps_seed_list, 
            "eps_seed_auto_increment": eps_seed_auto_increment
        }

    logger(
        f"[Eps] Loading finished. "
        "\n"
    )

    # ---------= [Eta] =---------
    logger(f"[Eta] Loading started. ")

    eta_random = get_true_value(cfg["task"]["eta"]["random"])

    logger(f"    eta_random: {eta_random}")
    
    eta_low = get_true_value(cfg["task"]["eta"]["eta_low"])
    eta_high = get_true_value(cfg["task"]["eta"]["eta_high"])

    logger(f"    eta_low: {eta_low}")
    logger(f"    eta_high: {eta_high}")

    default_eta = None
    if not eta_random:
        default_eta = get_true_value(cfg["task"]["eta"]["default_eta"])

    logger(
        f"[Eta] Loading finished. "
        "\n"
    )

    cfg_dict["eta"] = {
        "eta_low": eta_low, 
        "eta_high": eta_high, 

        "default_eta": default_eta
    }

    # ---------= [Sample] =---------
    logger(f"[Sample] Loading started. ")

    prompt_2 = get_true_value(cfg["task"]["sample"]["prompt_2"])
    negative_prompt = get_true_value(cfg["task"]["sample"]["negative_prompt"])
    negative_prompt_2 = get_true_value(cfg["task"]["sample"]["negative_prompt_2"])

    logger(f"    prompt_2: {prompt_2}")
    logger(f"    negative_prompt: {negative_prompt}")
    logger(f"    negative_prompt_2: {negative_prompt_2}")

    height = get_true_value(cfg["task"]["sample"]["height"])
    width = get_true_value(cfg["task"]["sample"]["width"])
    down_sampling_ratio = get_true_value(cfg["task"]["sample"]["down_sampling_ratio"])

    logger(f"    height: {height}")
    logger(f"    width: {width}")
    logger(f"    down_sampling_ratio: {down_sampling_ratio}")

    num_inference_step = get_true_value(cfg["task"]["sample"]["num_inference_step"])
    guidance_scale = get_true_value(cfg["task"]["sample"]["guidance_scale"])

    logger(f"    num_inference_step: {num_inference_step}")
    logger(f"    guidance_scale: {guidance_scale}")

    logger(
        f"[Sample] Loading finished. "
        "\n"
    )

    cfg_dict["sample"] = {
        "prompt_2": prompt_2, 
        "negative_prompt": negative_prompt, 
        "negative_prompt_2": negative_prompt_2, 

        "height": height, 
        "width": width, 
        "down_sampling_ratio": down_sampling_ratio, 

        "num_inference_step": num_inference_step, 

        "guidance_scale": guidance_scale
    }

    # ---------= [Task] =---------
    logger(f"[Task] Loading started. ")
    
    num_sample_per_prompt = get_true_value(cfg["task"]["task"]["num_sample_per_prompt"])
    # prompt_batch_size = get_true_value(cfg["task"]["task"]["prompt_batch_size"])
    # batch_size = get_true_value(cfg["task"]["task"]["batch_size"])

    logger(f"    num_sample_per_prompt: {num_sample_per_prompt}")
    # logger(f"    prompt_batch_size: {prompt_batch_size}")
    # logger(f"    batch_size: {batch_size}")

    logger(
        f"[Task] Loading finished. "
        "\n"
    )

    cfg_dict["task"] = {
        "num_sample_per_prompt": num_sample_per_prompt
    }

    # ---------= [Promptist] =---------
    logger(f"[Promptist] Loading started. ")
    
    enable_promptist = get_true_value(cfg["task"]["promptist"]["enable"])
    cfg_yaml_path_promptist = get_true_value(cfg["task"]["promptist"]["cfg_yaml_path"])

    logger(f"    enable_promptist: {enable_promptist}")
    logger(f"    cfg_yaml_path_promptist: {cfg_yaml_path_promptist}")

    logger(
        f"[Promptist] Loading finished. "
        "\n"
    )

    cfg_dict["promptist"] = {
        "enable": enable_promptist, 
        "cfg_yaml_path": cfg_yaml_path_promptist
    }

    # ---------= [Golden Noise] =---------
    logger(f"[Golden Noise] Loading started. ")
    
    enable_golden_noise = get_true_value(cfg["task"]["golden_noise"]["enable"])
    cfg_yaml_path_golden_noise = get_true_value(cfg["task"]["golden_noise"]["cfg_yaml_path"])

    logger(f"    enable_golden_noise: {enable_golden_noise}")
    logger(f"    cfg_yaml_path_golden_noise: {cfg_yaml_path_golden_noise}")

    logger(
        f"[Golden Noise] Loading finished. "
        "\n"
    )

    cfg_dict["golden_noise"] = {
        "enable": enable_golden_noise, 
        "cfg_yaml_path": cfg_yaml_path_golden_noise
    }

    # ---------= [Save] =---------
    logger(f"[Save] Loading started. ")
    
    save_root_path = get_true_value(cfg["task"]["save"]["save_root_path"])

    logger(f"    save_root_path: {save_root_path}")

    logger(
        f"[Save] Loading finished. "
        "\n"
    )

    cfg_dict["save"] = {
        "save_root_path": save_root_path
    }

    # ---------= [Reward Model] =---------
    logger(f"[Reward Model] Loading started. ")
    
    reward_model_type = get_true_value(cfg["task"]["reward_model"]["reward_model_type"])
    cal_dynamics_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_dynamics_batch_size"])
    cal_intermediate_reward_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_intermediate_reward_batch_size"])
    cal_final_reward_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_final_reward_batch_size"])
    # disable_intermediate_reward = get_true_value(cfg["task"]["reward_model"]["disable_intermediate_reward"])
    cal_intermediate_reward_policy = get_true_value(cfg["task"]["reward_model"]["cal_intermediate_reward_policy"])
    reward_shaping_policy = get_true_value(cfg["task"]["reward_model"]["reward_shaping_policy"])
    # potential_exp_growing = get_true_value(cfg["task"]["reward_model"]["potential_exp_growing"])
    # potential_exp_base = get_true_value(cfg["task"]["reward_model"]["potential_exp_base"])

    logger(f"    reward_model_type: {reward_model_type}")
    logger(f"    cal_dynamics_batch_size: {cal_dynamics_batch_size}")
    logger(f"    cal_intermediate_reward_batch_size: {cal_intermediate_reward_batch_size}")
    logger(f"    cal_final_reward_batch_size: {cal_final_reward_batch_size}")
    # logger(f"    disable_intermediate_reward: {disable_intermediate_reward}")
    logger(f"    cal_intermediate_reward_policy: {cal_intermediate_reward_policy}")
    logger(f"    reward_shaping_policy: {reward_shaping_policy}")
    # logger(f"    potential_exp_growing: {potential_exp_growing}")
    # logger(f"    potential_exp_base: {potential_exp_base}")

    cfg_dict["reward_model"] = {
        "reward_model_type": reward_model_type, 
        "cal_dynamics_batch_size": cal_dynamics_batch_size, 
        "cal_intermediate_reward_batch_size": cal_intermediate_reward_batch_size, 
        "cal_final_reward_batch_size": cal_final_reward_batch_size, 
        # "disable_intermediate_reward": disable_intermediate_reward, 
        "cal_intermediate_reward_policy": cal_intermediate_reward_policy, 
        "reward_shaping_policy": reward_shaping_policy, 
        # "potential_exp_growing": potential_exp_growing, 
        # "potential_exp_base": potential_exp_base
    }

    num_look_ahead_step = None
    gamma = None

    if cal_intermediate_reward_policy == "look_ahead":
        num_look_ahead_step = get_true_value(cfg["task"]["reward_model"]["num_look_ahead_step"])

        logger(f"    num_look_ahead_step: {num_look_ahead_step}")

        cfg_dict["reward_model"]["num_look_ahead_step"] = num_look_ahead_step
    elif cal_intermediate_reward_policy == "discount":
        gamma = get_true_value(cfg["task"]["reward_model"]["gamma"])

        logger(f"    gamma: {gamma}")

        cfg_dict["reward_model"]["gamma"] = gamma

    logger(
        f"[Reward Model] Loading finished. "
        "\n"
    )

    # ---------= [LRU Cache] =---------
    logger(f"[LRU Cache] Loading started. ")
    
    num_gpu_resident_lim = get_true_value(cfg["task"]["lru_cache"]["num_gpu_resident_lim"])

    logger(f"    num_gpu_resident_lim: {num_gpu_resident_lim}")

    logger(
        f"[LRU Cache] Loading finished. "
        "\n"
    )

    cfg_dict["lru_cache"] = {
        "num_gpu_resident_lim": num_gpu_resident_lim
    }

    # ---------= [MCTS] =---------
    logger(f"[MCTS] Loading started. ")
    
    mdp_modeling = get_true_value(cfg["task"]["mcts"]["mode"]["mdp_modeling"])
    value_policy = get_true_value(cfg["task"]["mcts"]["mode"]["value_policy"])
    pseudo_latent_as_final = get_true_value(cfg["task"]["mcts"]["mode"]["pseudo_latent_as_final"])
    enable_pseudo_latent_as_final_depth = get_true_value(cfg["task"]["mcts"]["mode"]["enable_pseudo_latent_as_final_depth"])
    exploration_coef = get_true_value(cfg["task"]["mcts"]["ucb"]["exploration_coef"])
    # depth_coef = get_true_value(cfg["task"]["mcts"]["ucb"]["depth_coef"])
    # exclude_last_intermediate_reward = get_true_value(cfg["task"]["mcts"]["ucb"]["exclude_last_intermediate_reward"])
    selection_depth_lim = get_true_value(cfg["task"]["mcts"]["selection"]["selection_depth_lim"])
    expansion_action_sampling_policy = get_true_value(cfg["task"]["mcts"]["expansion"]["expansion_action_sampling_policy"])
    expansion_enable_importance_sampling = get_true_value(cfg["task"]["mcts"]["expansion"]["enable_importance_sampling"])
    expansion_importance_sampling_J_star_scaling_factor = get_true_value(cfg["task"]["mcts"]["expansion"]["importance_sampling_J_star_scaling_factor"])
    expansion_importance_sampling_eps = get_true_value(cfg["task"]["mcts"]["expansion"]["importance_sampling_eps"])
    expansion_importance_sampling_verbose = get_true_value(cfg["task"]["mcts"]["expansion"]["importance_sampling_verbose"])
    per_iteration_expansion_lim = get_true_value(cfg["task"]["mcts"]["expansion"]["per_iteration_expansion_lim"])
    simulation_action_sampling_policy = get_true_value(cfg["task"]["mcts"]["simulation"]["simulation_action_sampling_policy"])
    simulation_default_action_list = get_true_value(cfg["task"]["mcts"]["simulation"]["default_action_list"])
    nfe_cal_dynamics_lim = get_true_value(cfg["task"]["mcts"]["nfe_limit"]["nfe_cal_dynamics_lim"])
    nfe_cal_intermediate_reward_lim = get_true_value(cfg["task"]["mcts"]["nfe_limit"]["nfe_cal_intermediate_reward_lim"])
    nfe_cal_final_reward_lim = get_true_value(cfg["task"]["mcts"]["nfe_limit"]["nfe_cal_final_reward_lim"])
    # beta_parameterization = get_true_value(cfg["task"]["mcts"]["beta"]["beta_parameterization"])

    logger(f"    mdp_modeling: {mdp_modeling}")
    logger(f"    value_policy: {value_policy}")
    logger(f"    pseudo_latent_as_final: {pseudo_latent_as_final}")
    logger(f"    enable_pseudo_latent_as_final_depth: {enable_pseudo_latent_as_final_depth}")
    logger(f"    exploration_coef: {exploration_coef}")
    # logger(f"    depth_coef: {depth_coef}")
    # logger(f"    exclude_last_intermediate_reward: {exclude_last_intermediate_reward}")
    logger(f"    selection_depth_lim: {selection_depth_lim}")
    logger(f"    expansion_action_sampling_policy: {expansion_action_sampling_policy}")
    logger(f"    expansion_enable_importance_sampling: {expansion_enable_importance_sampling}")
    logger(f"    expansion_importance_sampling_J_star_scaling_factor: {expansion_importance_sampling_J_star_scaling_factor}")
    logger(f"    expansion_importance_sampling_eps: {expansion_importance_sampling_eps}")
    logger(f"    expansion_importance_sampling_verbose: {expansion_importance_sampling_verbose}")
    logger(f"    per_iteration_expansion_lim: {per_iteration_expansion_lim}")
    logger(f"    simulation_action_sampling_policy: {simulation_action_sampling_policy}")
    logger(f"    simulation_default_action_list: {simulation_default_action_list}")
    logger(f"    nfe_cal_dynamics_lim: {nfe_cal_dynamics_lim}")
    logger(f"    nfe_cal_intermediate_reward_lim: {nfe_cal_intermediate_reward_lim}")
    logger(f"    nfe_cal_final_reward_lim: {nfe_cal_final_reward_lim}")
    # logger(f"    beta_parameterization: {beta_parameterization}")

    cfg_dict["mcts"] = {
        "mode": {
            "mdp_modeling": mdp_modeling, 
            "value_policy": value_policy, 
            "pseudo_latent_as_final": pseudo_latent_as_final, 
            "enable_pseudo_latent_as_final_depth": enable_pseudo_latent_as_final_depth
        }, 
        "ucb": {
            "exploration_coef": exploration_coef, 
            # "depth_coef": depth_coef, 
            # "exclude_last_intermediate_reward": exclude_last_intermediate_reward, 
        }, 
        "selection": {
            "selection_depth_lim": selection_depth_lim
        }, 
        "expansion": {
            "expansion_action_sampling_policy": expansion_action_sampling_policy, 
            "expansion_enable_importance_sampling": expansion_enable_importance_sampling, 
            "expansion_importance_sampling_J_star_scaling_factor": expansion_importance_sampling_J_star_scaling_factor, 
            "expansion_importance_sampling_eps": expansion_importance_sampling_eps, 
            "expansion_importance_sampling_verbose": expansion_importance_sampling_verbose, 
            "per_iteration_expansion_lim": per_iteration_expansion_lim
        }, 
        "simulation": {
            "simulation_action_sampling_policy": simulation_action_sampling_policy, 
            "simulation_default_action_list": simulation_default_action_list
        }, 
        "nfe_limit": {
            "nfe_cal_dynamics_lim": nfe_cal_dynamics_lim, 
            "nfe_cal_intermediate_reward_lim": nfe_cal_intermediate_reward_lim, 
            "nfe_cal_final_reward_lim": nfe_cal_final_reward_lim
        }, 
        "beta": {
            # "beta_parameterization": beta_parameterization
        }
    }

    # optimal_control_online_update = None
    # optimal_control_update_reward_threshold = None
    # optimal_control_omega_z = None
    # optimal_control_omega_eta = None
    # optimal_control_finite_difference_accuracy_order = None
    # optimal_control_finite_difference_eps = None
    # optimal_control_force_positive_semi_definite_max_tolerance = None
    # optimal_control_force_positive_definite_max_tolerance = None

    # if expansion_action_sampling_policy == "optimal_control_beta":
    #     optimal_control_online_update = get_true_value(cfg["task"]["mcts"]["optimal_control_online_update"])
    #     optimal_control_update_reward_threshold = get_true_value(cfg["task"]["mcts"]["optimal_control_update_reward_threshold"])
    #     optimal_control_omega_z = get_true_value(cfg["task"]["mcts"]["optimal_control_omega_z"])
    #     optimal_control_omega_eta = get_true_value(cfg["task"]["mcts"]["optimal_control_omega_eta"])
    #     optimal_control_finite_difference_accuracy_order = get_true_value(cfg["task"]["mcts"]["optimal_control_finite_difference_accuracy_order"])
    #     optimal_control_finite_difference_eps = get_true_value(cfg["task"]["mcts"]["optimal_control_finite_difference_eps"])
    #     optimal_control_force_positive_semi_definite_max_tolerance = get_true_value(cfg["task"]["mcts"]["optimal_control_force_positive_semi_definite_max_tolerance"])
    #     optimal_control_force_positive_definite_max_tolerance = get_true_value(cfg["task"]["mcts"]["optimal_control_force_positive_definite_max_tolerance"])

    #     logger(f"    optimal_control_online_update: {optimal_control_online_update}")
    #     logger(f"    optimal_control_update_reward_threshold: {optimal_control_update_reward_threshold}")
    #     logger(f"    optimal_control_omega_z: {optimal_control_omega_z}")
    #     logger(f"    optimal_control_omega_eta: {optimal_control_omega_eta}")
    #     logger(f"    optimal_control_finite_difference_accuracy_order: {optimal_control_finite_difference_accuracy_order}")
    #     logger(f"    optimal_control_finite_difference_eps: {optimal_control_finite_difference_eps}")
    #     logger(f"    optimal_control_force_positive_semi_definite_max_tolerance: {optimal_control_force_positive_semi_definite_max_tolerance}")
    #     logger(f"    optimal_control_force_positive_definite_max_tolerance: {optimal_control_force_positive_definite_max_tolerance}")

    beta_online_update = None
    beta_update_policy = None
    beta_value_gradient_update_time = None
    beta_update_step_size = None
    beta_max_update_bias = None
    beta_action_bias = None
    beta_zeta_list = None
    beta_update_reward_threshold = None
    beta_clamp_eps = None
    beta_direction_length_eps = None

    if expansion_action_sampling_policy == "beta":
        beta_online_update = get_true_value(cfg["task"]["mcts"]["beta"]["online_update"])
        beta_update_policy = get_true_value(cfg["task"]["mcts"]["beta"]["update_policy"])
        beta_value_gradient_update_time = get_true_value(cfg["task"]["mcts"]["beta"]["value_gradient_update_time"])
        beta_update_step_size = get_true_value(cfg["task"]["mcts"]["beta"]["update_step_size"])
        beta_action_bias = get_true_value(cfg["task"]["mcts"]["beta"]["action_bias"])
        beta_max_update_bias = get_true_value(cfg["task"]["mcts"]["beta"]["max_update_bias"])
        beta_zeta_list = get_true_value(cfg["task"]["mcts"]["beta"]["zeta_list"])
        beta_update_reward_threshold = get_true_value(var=cfg["task"]["mcts"]["beta"]["update_reward_threshold"])
        beta_clamp_eps = get_true_value(cfg["task"]["mcts"]["beta"]["clamp_eps"])
        beta_direction_length_eps = get_true_value(cfg["task"]["mcts"]["beta"]["direction_length_eps"])

        logger(f"    beta_online_update: {beta_online_update}")
        logger(f"    beta_update_policy: {beta_update_policy}")
        logger(f"    beta_value_gradient_update_time: {beta_value_gradient_update_time}")
        logger(f"    beta_update_step_size: {beta_update_step_size}")
        logger(f"    beta_action_bias: {beta_action_bias}")
        logger(f"    beta_max_update_bias: {beta_max_update_bias}")
        logger(f"    beta_zeta_list: {beta_zeta_list}")
        logger(f"    beta_update_reward_threshold: {beta_update_reward_threshold}")
        logger(f"    beta_clamp_eps: {beta_clamp_eps}")
        logger(f"    beta_direction_length_eps: {beta_direction_length_eps}")

        cfg_dict["mcts"]["beta"] = {
            "beta_online_update": beta_online_update, 
            "beta_update_policy": beta_update_policy, 
            "beta_value_gradient_update_time": beta_value_gradient_update_time, 
            "beta_action_bias": beta_action_bias, 
            "beta_update_step_size": beta_update_step_size, 
            "beta_max_update_bias": beta_max_update_bias, 
            "beta_zeta_list": beta_zeta_list, 
            "beta_update_reward_threshold": beta_update_reward_threshold, 
            "beta_clamp_eps": beta_clamp_eps, 
            "beta_direction_length_eps": beta_direction_length_eps
        }

    logger(
        f"[MCTS] Loading finished. "
        "\n"
    )

    # ---------= [Display] =---------
    logger(f"[Display] Loading started. ")
    
    display_trajectory = get_true_value(cfg["task"]["display"]["display_trajectory"])
    display_selected_node_depth = get_true_value(cfg["task"]["display"]["display_selected_node_depth"])
    display_cal_state_value = get_true_value(cfg["task"]["display"]["display_cal_state_value"])
    display_reward_sum_to_leaf = get_true_value(cfg["task"]["display"]["display_reward_sum_to_leaf"])
    display_beta_mode_update = get_true_value(cfg["task"]["display"]["display_beta_mode_update"])

    logger(f"    display_trajectory: {display_trajectory}")
    logger(f"    display_selected_node_depth: {display_selected_node_depth}")
    logger(f"    display_cal_state_value: {display_cal_state_value}")
    logger(f"    display_reward_sum_to_leaf: {display_reward_sum_to_leaf}")
    logger(f"    display_beta_mode_update: {display_beta_mode_update}")

    logger(
        f"[Display] Loading finished. "
        "\n"
    )

    cfg_dict["display"] = {
        "display_trajectory": display_trajectory, 
        "display_selected_node_depth": display_selected_node_depth, 
        "display_cal_state_value": display_cal_state_value, 
        "display_reward_sum_to_leaf": display_reward_sum_to_leaf, 
        "display_beta_mode_update": display_beta_mode_update
    }

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Check Input] =---------
    if reward_shaping_policy != "latent_reward":
        logger(
            f"Only support `reward_shaping_policy = 'latent_reward'` in beam search. ", 

            log_type = "error"
        )

        breakpoint()

    if is_eps_action:
        if expansion_action_sampling_policy != "uniform":
            logger(
                f"Only support `expansion_action_sampling_policy = 'uniform'` in beam search. ", 

                log_type = "error"
            )

            breakpoint()

        if simulation_action_sampling_policy != "uniform":
            logger(
                f"Only support `simulation_action_sampling_policy = 'uniform'` in beam search. ", 

                log_type = "error"
            )

            breakpoint()

    # ---------= [Prepare Pipeline] =---------
    (
        pipeline_category_name, 
        pipeline_type_name
    ) = get_pipeline_category_and_type(
        pipeline_path = pipeline_path
    )

    pipeline = load_pipeline(
        pipeline_type = pipeline_type, 
        pipeline_path = pipeline_path, 
        torch_dtype = pipeline_torch_dtype, 
        variant = pipeline_variant
    )

    pipeline.scheduler = load_scheduler(
        pipeline = pipeline, 
        scheduler_type = scheduler_type
    )
    
    inference_step_minus_one = get_inference_step_minus_one(scheduler_type)

    if pipeline_category_name == "sd_family":
        from my_diffusers.pipeline_stable_diffusion import register_pipeline_stable_diffusion

        # register custom methods: `forward()`, `prepare_everything()`, `step()`
        register_pipeline_stable_diffusion(pipeline)
    elif pipeline_category_name == "sdxl_family":
        from my_diffusers.pipeline_stable_diffusion_xl import register_pipeline_stable_diffusion_xl

        # register custom methods: `forward()`, `prepare_everything()`, `step()`
        register_pipeline_stable_diffusion_xl(pipeline)
    elif pipeline_category_name == "pixart_alpha_family":
        from my_diffusers.pipeline_pixart_alpha import register_pipeline_pixart_alpha

        # register custom methods: `forward()`, `prepare_everything()`, `get_noise_pred()`, `step()`
        register_pipeline_pixart_alpha(pipeline)
    else:
        raise NotImplementedError(
            f"Unsupported `pipeline_category_name`, got `{pipeline_category_name}`. "
        )
    
    register_scheduling_ddim(
        scheduler = pipeline.scheduler
    )

    # move to GPU after registration
    pipeline = pipeline.to(device)
    
    # save VRAM by offloading the model to CPU
    pipeline.enable_model_cpu_offload()
    
    logger(f"    pipeline: {type(pipeline)}")
    # logger(f"    pipeline: {pipeline}")

    cfg_dict["pipeline"] = {
        "pipeline_category_name": pipeline_category_name, 
        "pipeline_type_name": pipeline_type_name, 
        "inference_step_minus_one": inference_step_minus_one
    }

    # ---------= [Prompt List] =---------
    prompt_manager = get_prompt_manager(
        prompt_manager_dict = prompt_manager_dict
    )

    prompt_manager.load_prompt_list()
    prompt_manager.prepare_everything(
        shuffle = True
    )

    prompt_list = list(prompt_manager.prompt_list)
    folder_name_list = prompt_manager.folder_name_list
    
    if num_prompt is not None:
        prompt_list = prompt_list[: num_prompt]
        folder_name_list = folder_name_list[: num_prompt]
    else:
        num_prompt = len(prompt_list)
    
    logger(f"prompt_list: {prompt_list}")

    # ---------= [Promptist] =---------
    optimized_prompt_list = None

    if enable_promptist:
        from ..util.promptist_util import (
            load_promptist_model, 
            optimize_text_prompt
        )

        (
            promptist_model, 
            promptist_tokenizer
        ) = load_promptist_model(
            cfg_yaml_path = cfg_yaml_path_promptist, 

            device = device
        )

        optimized_prompt_list = [
            optimize_text_prompt(
                prompt = prompt, 

                promptist_model = promptist_model, 
                promptist_tokenizer = promptist_tokenizer
            ) \
                for prompt in prompt_list
        ]

        # NB: 
        # `folder_name_list` is still generated with the original prompts

        # ---------= [Clean Up] =---------
        del promptist_model, promptist_tokenizer
        torch.cuda.empty_cache()
        gc.collect()

        pass

    # # prompt_tuple_list: (prompt_idx, prompt)
    # # prompt_tuple_list.shape = (num_prompt, 2)
    # prompt_tuple_list = [
    #     (i, prompt) \
    #         for i, prompt in enumerate(prompt_list)
    # ]

    # # prompt_tuple_list.shape = (num_prompt * num_sample_per_prompt, 2)
    # prompt_tuple_list = np.repeat(
    #     np.array(
    #         prompt_tuple_list, 
    #         dtype = object
    #     ), 
    #     num_sample_per_prompt, 

    #     axis = 0
    # ).tolist()

    # ---------= [Prepare Task] =---------
    num_sample = num_prompt * num_sample_per_prompt

    # num_batch = (num_sample + batch_size - 1) // batch_size

    cfg_dict["task"] = {
        "num_prompt": num_prompt, 
        "num_sample": num_sample
    }

    # ---------= [Prepare Init Latent] =---------
    latent_num_channel = 4
    latent_height = height // down_sampling_ratio
    latent_width = width // down_sampling_ratio

    latent_shape = (latent_num_channel, latent_height, latent_width)
    # latent_shape_str = f"{latent_num_channel}_{latent_height}_{latent_width}"

    if not init_latent_random:
        init_latent_seed_list = prepare_seed_list(
            seed_list = init_latent_seed_list, 
            target_length = num_sample_per_prompt, 

            auto_inrement = init_latent_seed_auto_increment
        )

        cfg_dict["init_latent"]["init_latent_seed_list"] = init_latent_seed_list
        
        init_latent_list = [
            get_latent(
                shape = (latent_num_channel, latent_height, latent_width), 
                
                seed = init_latent_seed_list[sample_idx % num_sample_per_prompt], 

                device = "cpu", 

                dtype = pipeline_torch_dtype
            ) \
                for sample_idx in range(num_sample)
        ]
    else:
        raise NotImplementedError(
            f"Only support `init_latent_random = False`. "
        )

    # init_latent_list.shape = (num_sample, latent_num_channel, latent_height, latent_width)
    init_latent_list = torch.stack(init_latent_list)

    # ---------= [Golden Noise] =---------
    if enable_golden_noise:
        from ..util.golden_noise.util import (
            load_golden_noise_model, 
            optimize_initial_noise
        )

        if pipeline_category_name == "sdxl_family":
            npnet = load_golden_noise_model(
                cfg_yaml_path = cfg_yaml_path_golden_noise, 

                sd_type = "SDXL", 

                device = device
            )
        else:
            raise NotImplementedError(
                f"Only support SDXL pipeline. "
            )

        if enable_promptist:
            duplicated_prompt_list = [
                optimized_prompt_list[sample_idx % num_sample_per_prompt] \
                    for sample_idx in range(num_sample)
            ]
        else:
            duplicated_prompt_list = [
                prompt_list[sample_idx % num_sample_per_prompt] \
                    for sample_idx in range(num_sample)
            ]

        optimized_init_latent_list = optimize_initial_noise(
            initial_noise_list = init_latent_list, 

            npnet = npnet, 

            prompt_list = duplicated_prompt_list, 

            pipeline = pipeline
        )

        init_latent_list = init_latent_list.cpu()

        # ---------= [Clean Up] =---------
        del duplicated_prompt_list
        del npnet
        torch.cuda.empty_cache()
        gc.collect()

        pass

    # ---------= [Prepare Eps] =---------
    eps_list = None
    if (not is_eps_action) and (not eps_random):
        eps_seed_list = prepare_seed_list(
            seed_list = eps_seed_list, 
            target_length = num_inference_step, 

            auto_inrement = eps_seed_auto_increment
        )

        cfg_dict["eps"]["eps_seed_list"] = eps_seed_list

        eps_list = [
            get_latent(
                shape = (latent_num_channel, latent_height, latent_width), 
                
                seed = seed, 

                device = device, 

                dtype = pipeline_torch_dtype
            ) \
                for seed in eps_seed_list
        ]

    # ---------= [Prepare Save Sample] =---------
    save_root_path = Path(save_root_path)
    pipeline_root_path = save_root_path / pipeline_type_name

    dataset_type_str = prompt_manager_dict["prompt_manager_type"]
    dataset_root_path = pipeline_root_path / dataset_type_str

    reward_model_type_str = reward_model_type
    if reward_model_type == "color_channel_reward":
        reward_model_cfg_yaml_path = Path("./config/reward_model/color_channel_reward.yaml")
        reward_model_cfg_dict = load_yaml(reward_model_cfg_yaml_path)
        target_channel_idx = reward_model_cfg_dict["color_channel_reward"]["target_channel_idx"]
        reward_model_type_str = f"{reward_model_type_str}-{target_channel_idx}"
    elif reward_model_type == "laplacian_var_reward":
        reward_model_cfg_yaml_path = Path("./config/reward_model/laplacian_var_reward.yaml")
        reward_model_cfg_dict = load_yaml(reward_model_cfg_yaml_path)
        tsfm_to_gray = reward_model_cfg_dict["laplacian_var_reward"]["tsfm_to_gray"]
        if tsfm_to_gray:
            reward_model_type_str = f"{reward_model_type_str}-gray"
        else:
            reward_model_type_str = f"{reward_model_type_str}-rgb"
    elif reward_model_type == "compressibility_hps_v2":
        reward_model_cfg_yaml_path = Path("./config/reward_model/compressibility_hps_v2.yaml")
        reward_model_cfg_dict = load_yaml(reward_model_cfg_yaml_path)
        lam = reward_model_cfg_dict["compressibility_hps_v2"]["lam"]
        reward_model_type_str = f"{reward_model_type_str}-{lam}"

    reward_model_root_path = dataset_root_path / reward_model_type_str
    reward_shaping_policy_root_path = reward_model_root_path / reward_shaping_policy

    mode_root_path = reward_shaping_policy_root_path / f"{mdp_modeling}_{value_policy}"
    if pseudo_latent_as_final:
        if enable_pseudo_latent_as_final_depth is None:
            enable_pseudo_latent_as_final_depth_str = "1"
        else:
            enable_pseudo_latent_as_final_depth_str = f"{enable_pseudo_latent_as_final_depth}"

        mode_root_path = mode_root_path / f"pseudo_latent_as_final-{enable_pseudo_latent_as_final_depth_str}"
    else:
        mode_root_path = mode_root_path / f"final_latent_only"

    inference_step_root_path = mode_root_path / f"{num_inference_step}"
    exploration_root_path = inference_step_root_path / f"{exploration_coef}"

    if selection_depth_lim is None:
        selection_depth_lim_str = f"{num_inference_step - 1}"
    else:
        selection_depth_lim_str = f"{selection_depth_lim}"
    selection_depth_root_path = exploration_root_path / f"{selection_depth_lim_str}"

    # if exclude_last_intermediate_reward:
    #     ucb_root_path = ucb_root_path / "exclude_last_intermediate_reward"
    # else:
    #     ucb_root_path = ucb_root_path / "include_last_intermediate_reward"
    # expansion_policy_root_path = selection_depth_root_path / expansion_action_sampling_policy
    # expansion_policy_root_path = selection_depth_root_path

    if expansion_action_sampling_policy == "beta":
        beta_policy_str = f"beta"
        beta_policy_str = f"{beta_policy_str}-{beta_update_policy}"
        if beta_update_policy == "value_gradient":
            beta_policy_str = f"{beta_policy_str}_{beta_value_gradient_update_time}"
        beta_policy_str = f"{beta_policy_str}-{beta_zeta_list}"

        if not beta_online_update:
            beta_policy_str = beta_policy_str = f"{beta_policy_str}-no_update"

        expansion_policy_root_path = selection_depth_root_path / beta_policy_str
    else:
        expansion_policy_root_path = selection_depth_root_path / expansion_action_sampling_policy

    if cal_intermediate_reward_policy == "look_ahead":
        cal_interemediate_reward_policy_root_path = expansion_policy_root_path / f"look_ahead-{num_look_ahead_step}"
    elif cal_intermediate_reward_policy == "discount":
        cal_interemediate_reward_policy_root_path = expansion_policy_root_path / f"discount-{gamma}"
    else:
        cal_interemediate_reward_policy_root_path = expansion_policy_root_path / cal_intermediate_reward_policy

    exp_root_path \
        = cal_interemediate_reward_policy_root_path / exp_time_str

    cfg_dict["save"]["exp_root_path"] = exp_root_path

    # ---------= [Prepare MDP] =---------
    action_shape = (1, )

    state_space = StateSpace(
        shape = latent_shape, 

        dtype = pipeline_torch_dtype, 
        device = device, 

        ver = "torch"
    )

    action_space_eta = None
    if (not is_eps_action) or eta_random:
        action_space_eta = ActionSpaceEta(
            eta_low = eta_low, 
            eta_high = eta_high, 

            shape = action_shape, 

            dtype = pipeline_torch_dtype, 
            device = device, 

            ver = "torch"
        )

    action_space_eps = None
    if is_eps_action or eps_random:
        action_space_eps = ActionSpaceEps(
            eps_seed_low = eps_seed_st, 
            eps_seed_high = eps_seed_ed, 

            shape = action_shape, 

            dtype = pipeline_torch_dtype, 
            device = device, 

            ver = "torch"
        )

    reward_shape = (1, )

    diffusion_mdp = DiffusionMDP(
        # ---------= [Beam Search] =---------
        is_eps_action = is_eps_action, 
        
        # ---------= [Eta] =---------
        eta_random = eta_random, 
        default_eta = default_eta, 

        state_space = state_space, 
        action_space_eta = action_space_eta, 
        action_space_eps = action_space_eps, 
        time_horizon = num_inference_step, 

        # ---------= [Parallel] =---------
        # cal_dynamics_batch_size = cal_dynamics_batch_size, 
        # cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
        # cal_final_reward_batch_size = cal_final_reward_batch_size, 

        # ---------= [Reward] =---------
        reward_shape = reward_shape, 

        # ---------= [Pipeline] =---------
        pipeline = pipeline, 
        eps_list = eps_list, 

        cal_intermediate_reward_policy = cal_intermediate_reward_policy
    )

    cfg_dict["eta"]["action_shape"] = action_shape
    cfg_dict["reward_model"]["reward_shape"] = reward_shape

    # ---------= [Prepare Everything] =---------
    diffusion_mdp.prepare_everything(
        prompt_list = prompt_list, 
        optimized_prompt_list = optimized_prompt_list, 

        prompt_2 = prompt_2, 
        negative_prompt = negative_prompt, 
        negative_prompt_2 = negative_prompt_2, 

        height = height, width = width, 
        guidance_scale = guidance_scale, 

        init_latent_list = init_latent_list, 

        num_sample_per_prompt = num_sample_per_prompt, 

        inference_step_minus_one = inference_step_minus_one
    )

    # init_latent_list.shape = (num_sample_per_prompt, latent_num_channel, latent_height, latent_width)
    init_latent_list = diffusion_mdp.init_latent_list

    # ---------= [Prepare Reward Model] =---------
    prompt_emb_list = diffusion_mdp.prompt_emb_list
    param_dict = diffusion_mdp.param_dict
    
    reward_model = get_reward_model(
        reward_model_type = reward_model_type, 

        # ---------= [Pipeline] =---------
        pipeline = pipeline, 
        num_inference_step = num_inference_step, 

        # ---------= [Param] =---------
        prompt_emb_list = prompt_emb_list, 
        param_dict = param_dict, 
        num_sample_per_prompt = num_sample_per_prompt, 

        # ---------= [Reward] =---------
        reward_shape = reward_shape, 
        reward_dtype = pipeline_torch_dtype, 
        offload_to_cpu = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size = cal_dynamics_batch_size, 
        cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
        cal_final_reward_batch_size = cal_final_reward_batch_size, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy = reward_shaping_policy, 
        # potential_exp_growing = potential_exp_growing, 
        # potential_exp_base = potential_exp_base, 
        cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

        device = device, 
        
        vae_decode_batch_size = vae_decode_batch_size
    )
    
    cal_intermediate_reward_arg_dict = {}

    if cal_intermediate_reward_policy == "look_ahead":
        cal_intermediate_reward_arg_dict["num_look_ahead_step"] = num_look_ahead_step
    elif cal_intermediate_reward_policy == "discount":
        cal_intermediate_reward_arg_dict["gamma"] = gamma

    diffusion_mdp.set_reward_model(
        reward_model = reward_model, 

        **cal_intermediate_reward_arg_dict
    )

    if reward_model_type == "color_channel_reward":
        target_channel_idx = reward_model.target_channel_idx
        cfg_dict["reward_model"]["target_channel_idx"] = target_channel_idx

    # ---------= [Prepare LRU Cache] =---------
    lru_cache = LRUCache(
        num_gpu_resident_lim = num_gpu_resident_lim, 

        device = device
    )

    # ---------= [Diffusion MCTS] =---------
    arg_dict = {}

    if cal_intermediate_reward_policy == "look_ahead":
        arg_dict["num_look_ahead_step"] = num_look_ahead_step
    elif cal_intermediate_reward_policy == "discount":
        arg_dict["gamma"] = gamma

    diffusion_mcts = DiffusionOTMCTS(
        is_eps_action = is_eps_action, 

        mdp = diffusion_mdp, 
        init_state_list = init_latent_list, 

        mdp_modeling = mdp_modeling, 
        value_policy = value_policy, 
        pseudo_latent_as_final = pseudo_latent_as_final, 
        enable_pseudo_latent_as_final_depth = enable_pseudo_latent_as_final_depth, 

        # ---------= [Upper Confidence Bound (UCB)] =---------
        exploration_coef = exploration_coef, 
        # depth_coef = depth_coef, 
        # exclude_last_intermediate_reward = exclude_last_intermediate_reward, 

        # ---------= [Selection Policy] =---------
        selection_depth_lim = selection_depth_lim, 

        # ---------= [Expansion Policy] =---------
        expansion_action_sampling_policy = expansion_action_sampling_policy, 
        expansion_enable_importance_sampling = expansion_enable_importance_sampling, 
        expansion_importance_sampling_J_star_scaling_factor = expansion_importance_sampling_J_star_scaling_factor, 
        expansion_importance_sampling_eps = expansion_importance_sampling_eps, 
        expansion_importance_sampling_verbose = expansion_importance_sampling_verbose, 
        per_iteration_expansion_lim = per_iteration_expansion_lim, 

        # ---------= [Simulation Policy] =---------
        simulation_action_sampling_policy = simulation_action_sampling_policy, 
        simulation_default_action_list = simulation_default_action_list, 

        # ---------= [NFE Limit] =---------
        nfe_cal_dynamics_lim = nfe_cal_dynamics_lim, 
        nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim, 
        nfe_cal_final_reward_lim = nfe_cal_final_reward_lim, 

        # ---------= [Optimal Control] =---------
        # optimal_control_online_update = optimal_control_online_update, 
        # optimal_control_update_reward_threshold = optimal_control_update_reward_threshold, 
        # optimal_control_omega_z = optimal_control_omega_z, 
        # optimal_control_omega_eta = optimal_control_omega_eta, 
        # optimal_control_finite_difference_accuracy_order = optimal_control_finite_difference_accuracy_order, 
        # optimal_control_finite_difference_eps = optimal_control_finite_difference_eps, 
        # optimal_control_force_positive_semi_definite_max_tolerance = optimal_control_force_positive_semi_definite_max_tolerance, 
        # optimal_control_force_positive_definite_max_tolerance = optimal_control_force_positive_definite_max_tolerance, 
        # optimal_control_clamp_eps = optimal_control_clamp_eps, 

        # ---------= [Beta Distribution Parameterization] =---------
        beta_online_update = beta_online_update, 
        beta_update_policy = beta_update_policy, 
        beta_value_gradient_update_time = beta_value_gradient_update_time, 
        beta_action_bias = beta_action_bias, 
        beta_update_step_size = beta_update_step_size, 
        beta_max_update_bias = beta_max_update_bias, 
        beta_zeta_list = beta_zeta_list, 
        beta_update_reward_threshold = beta_update_reward_threshold, 
        beta_clamp_eps = beta_clamp_eps,  
        beta_direction_length_eps = beta_direction_length_eps,  
        display_beta_mode_update = display_beta_mode_update, 

        # ---------= [Optimal Control Beta] =---------
        lru_cache = lru_cache, 

        # ---------= [Dtype] =---------
        dtype = pipeline_torch_dtype, 

        # ---------= [Save Root Path] =---------
        expansion_policy_root_path = exp_root_path, 
        folder_name_list = folder_name_list, 
        cfg_dict = cfg_dict, 

        **arg_dict
    )

    # ---------= [Run Diffusion MCTS] =---------
    diffusion_mcts.run(
        display_result = True, 

        display_trajectory = display_trajectory, 
        display_state = False, 
        display_action = True, 
        display_reward = True, 

        display_cal_state_value = display_cal_state_value, 
        display_selected_node_depth = display_selected_node_depth, 
        display_reward_sum_to_leaf = display_reward_sum_to_leaf
    )

    # ---------= [Work] =---------
    # def _implement_batch(
    #     batch_idx: int
    # ):
    #     batch_prompt_tuple_list = prompt_tuple_list[
    #         batch_idx * prompt_batch_size: 
    #         min((batch_idx + 1) * prompt_batch_size, num_prompt)
    #     ]

    #     logger(f"[Batch {batch_idx}]")
    #     logger(f"    batch_prompt_tuple_list: {batch_prompt_tuple_list}")

    #     with cf.ThreadPoolExecutor(
    #         max_workers = concurrent_max_worker
    #     ) as executor:
    #         for (prompt_idx, prompt) in batch_prompt_tuple_list:
    #             # prompt = shlex.quote(prompt)
    #             prompt = json.dumps(prompt)
                
    #             cmd = [
    #                 "python", 
    #                 "main.py", 

    #                 f"pipeline={pipeline_type_name}", 
    #                 f"task=search/optimal_control_mcts/{pipeline_type_name}/template", 

    #                 # init_latent
    #                 f"task.init_latent.root_path={init_latent_root_path}", 
    #                 f"task.init_latent.random={init_latent_random}", 
    #                 f"task.init_latent.seed_st={init_latent_seed_st}", 
    #                 f"task.init_latent.seed_ed={init_latent_seed_ed}", 
    #                 f"task.init_latent.seed_list={init_latent_seed_list}", 
    #                 f"task.init_latent.seed_auto_increment={init_latent_seed_auto_increment}", 

    #                 # eps
    #                 f"task.eps.root_path={eps_root_path}", 
    #                 f"task.eps.random={eps_random}", 
    #                 f"task.eps.seed_st={eps_seed_st}", 
    #                 f"task.eps.seed_ed={eps_seed_ed}", 
    #                 f"task.eps.seed_list={eps_seed_list}", 
    #                 f"task.eps.seed_auto_increment={eps_seed_auto_increment}", 
                    
    #                 # sample
    #                 f"task.sample.prompt={prompt}", 
    #                 f"task.sample.prompt_2={prompt_2}", 
    #                 f"task.sample.negative_prompt={negative_prompt}", 
    #                 f"task.sample.negative_prompt_2={negative_prompt_2}", 
    #                 f"task.sample.height={height}", 
    #                 f"task.sample.width={width}", 
    #                 f"task.sample.down_sampling_ratio={down_sampling_ratio}", 
    #                 f"task.sample.num_inference_step={num_inference_step}", 
    #                 f"task.sample.guidance_scale={guidance_scale}", 

    #                 # task
    #                 f"task.task.num_sample={num_sample_per_prompt}", 
    #                 f"task.task.batch_size={batch_size}", 

    #                 # save
    #                 f"task.save.save_root_path={save_root_path}", 

    #                 # reward_model
    #                 f"task.reward_model.reward_model_type={reward_model_type}", 
    #                 f"task.reward_model.cal_dynamics_batch_size={cal_dynamics_batch_size}", 
    #                 f"task.reward_model.cal_intermediate_reward_batch_size={cal_intermediate_reward_batch_size}", 
    #                 f"task.reward_model.cal_final_reward_batch_size={cal_final_reward_batch_size}", 
    #                 f"task.reward_model.disable_intermediate_reward={disable_intermediate_reward}", 
    #                 f"task.reward_model.cal_intermediate_reward_policy={cal_intermediate_reward_policy}", 
    #                 f"task.reward_model.num_look_ahead_step={num_look_ahead_step}", 
    #                 f"task.reward_model.gamma={gamma}", 
    #                 f"task.reward_model.use_difference_reward={use_difference_reward}", 

    #                 # action_space
    #                 f"task.action_space.eta_low={eta_low}", 
    #                 f"task.action_space.eta_high={eta_high}", 

    #                 # lru_cache
    #                 f"task.lru_cache.num_gpu_resident_lim={num_gpu_resident_lim}", 

    #                 # mcts
    #                 f"task.mcts.exploration_coef={exploration_coef}", 
    #                 f"task.mcts.depth_coef={depth_coef}", 
    #                 f"task.mcts.expansion_action_sampling_policy={expansion_action_sampling_policy}", 
    #                 f"task.mcts.expansion_default_action_list={expansion_default_action_list}", 
    #                 f"task.mcts.expansion_enable_importance_sampling={expansion_enable_importance_sampling}", 
    #                 f"task.mcts.expansion_importance_sampling_J_star_scaling_factor={expansion_importance_sampling_J_star_scaling_factor}", 
    #                 f"task.mcts.expansion_importance_sampling_eps={expansion_importance_sampling_eps}", 
    #                 f"task.mcts.expansion_importance_sampling_verbose={expansion_importance_sampling_verbose}", 
    #                 f"task.mcts.num_per_iteration_selection={num_per_iteration_selection}", 
    #                 f"task.mcts.per_iteration_expansion_lim={per_iteration_expansion_lim}", 
    #                 f"task.mcts.simulation_action_sampling_policy={simulation_action_sampling_policy}", 
    #                 f"task.mcts.simulation_default_action_list={simulation_default_action_list}", 
    #                 f"task.mcts.nfe_cal_dynamics_lim={nfe_cal_dynamics_lim}", 
    #                 f"task.mcts.nfe_cal_intermediate_reward_lim={nfe_cal_intermediate_reward_lim}", 
    #                 f"task.mcts.nfe_cal_final_reward_lim={nfe_cal_final_reward_lim}", 

    #                 f"task.mcts.optimal_control_online_update={optimal_control_online_update}", 
    #                 f"task.mcts.optimal_control_update_reward_threshold={optimal_control_update_reward_threshold}", 
    #                 f"task.mcts.optimal_control_omega_z={optimal_control_omega_z}", 
    #                 f"task.mcts.optimal_control_omega_eta={optimal_control_omega_eta}", 
    #                 f"task.mcts.optimal_control_finite_difference_accuracy_order={optimal_control_finite_difference_accuracy_order}", 
    #                 f"task.mcts.optimal_control_finite_difference_eps={optimal_control_finite_difference_eps}", 
    #                 f"task.mcts.optimal_control_force_positive_semi_definite_max_tolerance={optimal_control_force_positive_semi_definite_max_tolerance}", 
    #                 f"task.mcts.optimal_control_force_positive_definite_max_tolerance={optimal_control_force_positive_definite_max_tolerance}", 
                     
    #                 f"task.mcts.beta_parameterization={beta_parameterization}", 
    #                 f"task.mcts.beta_online_update={beta_online_update}", 
    #                 f"task.mcts.beta_zeta_list={beta_zeta_list}", 
    #                 f"task.mcts.beta_update_reward_threshold={beta_update_reward_threshold}", 
    #                 f"task.mcts.beta_clamp_eps={beta_clamp_eps}",

    #                 # display
    #                 f"task.display.display_trajectory={display_trajectory}", 
    #                 f"task.display.display_selected_node_depth={display_selected_node_depth}", 
    #                 f"task.display.display_cal_state_value={display_cal_state_value}", 
    #                 f"task.display.display_reward_sum_to_leaf={display_reward_sum_to_leaf}"
    #             ]

    #             logger(f"    cmd: {cmd}")

    #             # subprocess.run(cmd)

    #             executor.submit(
    #                 subprocess.run, 

    #                 cmd
    #             )

    #             time.sleep(1.0)

    #             # goto `for (prompt_idx, prompt)`
    #             pass

    #     # ---------= [Clean Up] =---------
    #     del batch_prompt_tuple_list
    #     gc.collect()

    #     # `implement_batch()` done
    #     pass

    # for batch_idx in range(num_batch):
    #     # ---------= [Sample Batch] =---------
    #     if (batch_idx < num_batch - 1) or (num_sample % batch_size == 0):
    #         true_batch_size = batch_size
    #     else:
    #         true_batch_size = num_sample % batch_size
        
    #     sample_idx_st = batch_idx * batch_size

    #     # ---------= [Prepare Init Latent Seed List] =---------
    #     arg_dict = {}

    #     init_latent_seed_list_ = []

    #     if init_latent_random:
    #         init_latent_seed_list_ = [
    #             random.randint(init_latent_seed_st, init_latent_seed_ed) \
    #                 for _ in range(true_batch_size)
    #         ]
    #     else:
    #         for sample_idx in range(true_batch_size):
    #             true_sample_idx = sample_idx_st + sample_idx
    #             prompt_true_sample_idx = true_sample_idx % num_sample_per_prompt

    #             init_latent_seed_list_.append(
    #                 init_latent_seed_list[prompt_true_sample_idx]
    #             )

    #             # `for sample_idx` done
    #             pass

    #     arg_dict["init_latent_seed_list"] = init_latent_seed_list_

    #     # ---------= [Prepare Eps Seed List List] =---------
    #     eps_seed_list_list = []

    #     for sample_idx in range(true_batch_size):
    #         if eps_random:
    #             eps_seed_list = [
    #                 random.randint(eps_seed_st, eps_seed_ed) \
    #                     for _ in range(num_inference_step)
    #             ]

    #             eps_seed_list_list.append(eps_seed_list)
    #         else:
    #             eps_seed_list_list.append(eps_seed_list[:])

    #             # `for sample_idx` done
    #             pass
        
    #     eps_seed_list_list = torch.tensor(eps_seed_list_list)

    #     arg_dict["eps_seed_list_list"] = eps_seed_list_list

    #     # ---------= [Do MCTS] =---------
    #     implement_batch(
    #         batch_idx = batch_idx, 
    #         true_batch_size = true_batch_size, 

    #         arg_dict = arg_dict
    #     )

    #     diffusion_mcts.run(
    #         display_result = True, 

    #         display_trajectory = display_trajectory, 
    #         display_state = False, 
    #         display_action = True, 
    #         display_reward = True, 

    #         display_cal_state_value = display_cal_state_value, 
    #         display_selected_node_depth = display_selected_node_depth, 
    #         display_reward_sum_to_leaf = display_reward_sum_to_leaf
    #     )

    #     # ---------= [Clean Up] =---------
    #     del arg_dict["init_latent_seed_list"]
    #     del arg_dict["eps_seed_list_list"]
    #     del arg_dict
    #     gc.collect()
    #     torch.cuda.empty_cache()

    #     # goto `for batch_idx`
    #     pass

    # `run_optimal_control_mcts_implement()` done
    pass


def run_optimal_control_mcts(
    cfg: DictConfig
):
    """
    Public entry point for the optimal-control MCTS task.

    Forwards `cfg` unchanged to `run_optimal_control_mcts_implement()`, which
    is declared with `@torch.no_grad()`. Whatever that call produces is
    discarded, so this wrapper always returns `None`.

    Args:
        cfg: The task configuration, passed through as-is.
    """
    # Resolve the implementation once, then invoke it with the config.
    implement_fn = run_optimal_control_mcts_implement
    implement_fn(cfg)

    # `run_optimal_control_mcts()` done
