from util.logger import logger

from typing import List

from omegaconf import DictConfig

import time

import concurrent.futures as cf

import shlex
import subprocess

import gc

from pathlib import Path

from tqdm.auto import tqdm

import numpy as np

import torch

import random

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value
)
from util.image_util import save_pil_as_png
from util.yaml_util import (
    load_yaml, 
    save_yaml, 
    convert_numpy_type_to_native_type
)
from util.torch_util import get_latent
from util.pipeline_util import (
    load_pipeline, load_scheduler, 
    get_inference_step_minus_one, 
    get_pipeline_category_and_type, 
    img_latent_to_pil, 
    get_folder_name
)

from prompt_manager.util import get_prompt_manager

from task.util.seed_list_util import prepare_seed_list

from my_diffusers.scheduling_ddim import register_scheduling_ddim

from OT_MCTS.src.lru_cache import LRUCache

from diffusion_OT_MCTS.mdp.state_space import StateSpace
from diffusion_OT_MCTS.mdp.action_space_eta import ActionSpaceEta
from diffusion_OT_MCTS.mdp.action_space_eps import ActionSpaceEps
from diffusion_OT_MCTS.mdp.diffusion_mdp import DiffusionMDP
from diffusion_OT_MCTS.reward_model.util import get_reward_model
from diffusion_OT_MCTS.diffusion_ot_bs import DiffusionOTBS


@torch.no_grad()
def run_optimal_control_bs_implement(
    cfg: DictConfig
):
    """Configure and run the `DiffusionOTBS` search for a set of prompts.

    The function works in two phases:

    1. Read every setting from `cfg` (and from the process-wide global
       variables), log it, and mirror it into `cfg_dict`, a plain dict that
       is later handed to `DiffusionOTBS`.
    2. Build the runtime objects in order: pipeline + scheduler, prompt
       list, initial latents, optional fixed eps noise, `DiffusionMDP`,
       reward model, `LRUCache`, and finally `DiffusionOTBS`, whose `run()`
       method is invoked.

    The whole function runs under `torch.no_grad()`.

    Args:
        cfg: Experiment configuration. The sections `cfg["pipeline"]` and
            `cfg["task"]` (`name`, `prompt_list`, `init_latent`, `eps`,
            `eta`, `sample`, `task`, `promptist`, `golden_noise`, `save`,
            `reward_model`, `lru_cache`, `beam_search`, `display`) are read.

    Raises:
        NotImplementedError: If `reward_shaping_policy` is not
            `"latent_reward"`, if an eps-action task uses an expansion
            policy other than `"uniform"`, if the pipeline category is not
            supported, if `init_latent.random` is true, or if golden noise
            is enabled for a non-SDXL pipeline.
        ValueError: If the task is an eps-action task but `eps.random` is
            false, so no eps seed range was loaded.
    """
    cfg_dict = {}

    # ---------= [Basic Global Variables] =---------
    exp_name = get_global_variable("exp_name")
    start_time = get_global_variable("start_time")
    device = get_global_variable("device")
    seed = get_global_variable("seed")
    exp_time_str = f"{exp_name}_{start_time}"

    concurrent_max_worker = get_global_variable("concurrent_max_worker")

    vae_decode_batch_size = get_global_variable("vae_decode_batch_size")

    cfg_dict["global_variable"] = {
        "exp_name": exp_name,
        "start_time": start_time,
        "device": device,
        "seed": seed,
        "exp_time_str": exp_time_str,
        "concurrent_max_worker": concurrent_max_worker,
        "vae_decode_batch_size": vae_decode_batch_size
    }

    # ---------= [Pipeline] =---------
    logger(f"[Pipeline] Loading started. ")

    pipeline_type = get_true_value(cfg["pipeline"]["pipeline_type"])
    pipeline_path = get_true_value(cfg["pipeline"]["pipeline_path"])
    pipeline_torch_dtype = get_true_value(cfg["pipeline"]["torch_dtype"])
    pipeline_variant = get_true_value(cfg["pipeline"]["variant"])

    logger(f"    pipeline_type: {pipeline_type}")
    logger(f"    pipeline_path: {pipeline_path}")
    logger(f"    pipeline_torch_dtype: {pipeline_torch_dtype}")
    logger(f"    pipeline_variant: {pipeline_variant}")

    scheduler_type = get_true_value(cfg["pipeline"]["scheduler"])

    logger(f"    scheduler_type: {scheduler_type}")

    # NOTE: this must be a dict (no trailing comma after the literal), because
    # the derived pipeline fields are merged into it further below.
    cfg_dict["pipeline"] = {
        "pipeline_type": pipeline_type,
        "pipeline_path": pipeline_path,
        # Recorded as text: the raw value is passed as a `dtype` argument to
        # the torch helpers below and may not be a plain serializable type
        # (assumption -- TODO confirm against `get_true_value`).
        "pipeline_torch_dtype": str(pipeline_torch_dtype),
        "pipeline_variant": pipeline_variant,

        "scheduler_type": scheduler_type
    }

    logger(
        f"[Pipeline] Loading finished. "
        "\n"
    )

    # ---------= [Eps Action] =---------
    task_name = get_true_value(cfg["task"]["name"])

    # The task name decides whether actions are eps seeds or eta values.
    is_eps_action = task_name.startswith("search-run_optimal_control_bs_eps")

    # ---------= [Prompt List] =---------
    logger(f"[Prompt List] Loading started. ")

    num_prompt = get_true_value(cfg["task"]["prompt_list"]["num_prompt"])
    prompt_manager_dict = get_true_value(cfg["task"]["prompt_list"]["prompt_manager_dict"])

    logger(f"    num_prompt: {num_prompt}")
    logger(f"    prompt_manager_dict: {prompt_manager_dict}")

    logger(
        f"[Prompt List] Loading finished. "
        "\n"
    )

    # ---------= [Init Latent] =---------
    logger(f"[Init Latent] Loading started. ")

    init_latent_random = get_true_value(cfg["task"]["init_latent"]["random"])

    logger(f"    init_latent_random: {init_latent_random}")

    cfg_dict["init_latent"] = {
        "init_latent_random": init_latent_random
    }

    if init_latent_random:
        init_latent_seed_st = get_true_value(cfg["task"]["init_latent"]["seed_st"])
        init_latent_seed_ed = get_true_value(cfg["task"]["init_latent"]["seed_ed"])

        logger(f"    init_latent_seed_st: {init_latent_seed_st}")
        logger(f"    init_latent_seed_ed: {init_latent_seed_ed}")

        # Extend (not replace) so that `init_latent_random` stays recorded.
        cfg_dict["init_latent"].update({
            "init_latent_seed_st": init_latent_seed_st,
            "init_latent_seed_ed": init_latent_seed_ed
        })
    else:
        init_latent_seed_list = get_true_value(cfg["task"]["init_latent"]["seed_list"])
        init_latent_seed_auto_increment = get_true_value(cfg["task"]["init_latent"]["seed_auto_increment"])

        logger(f"    init_latent_seed_list: {init_latent_seed_list}")
        logger(f"    init_latent_seed_auto_increment: {init_latent_seed_auto_increment}")

        cfg_dict["init_latent"].update({
            "init_latent_seed_list": init_latent_seed_list,
            "init_latent_seed_auto_increment": init_latent_seed_auto_increment
        })

    logger(
        f"[Init Latent] Loading finished. "
        "\n"
    )

    # ---------= [Eps] =---------
    logger(f"[Eps] Loading started. ")

    eps_random = get_true_value(cfg["task"]["eps"]["random"])

    logger(f"    eps_random: {eps_random}")

    cfg_dict["eps"] = {
        "eps_random": eps_random
    }

    if eps_random:
        eps_seed_st = get_true_value(cfg["task"]["eps"]["seed_st"])
        eps_seed_ed = get_true_value(cfg["task"]["eps"]["seed_ed"])

        logger(f"    eps_seed_st: {eps_seed_st}")
        logger(f"    eps_seed_ed: {eps_seed_ed}")

        # Extend (not replace) so that `eps_random` stays recorded.
        cfg_dict["eps"].update({
            "eps_seed_st": eps_seed_st,
            "eps_seed_ed": eps_seed_ed
        })
    else:
        eps_seed_list = get_true_value(cfg["task"]["eps"]["seed_list"])
        eps_seed_auto_increment = get_true_value(cfg["task"]["eps"]["seed_auto_increment"])

        logger(f"    eps_seed_list: {eps_seed_list}")
        logger(f"    eps_seed_auto_increment: {eps_seed_auto_increment}")

        cfg_dict["eps"].update({
            "eps_seed_list": eps_seed_list,
            "eps_seed_auto_increment": eps_seed_auto_increment
        })

    logger(
        f"[Eps] Loading finished. "
        "\n"
    )

    # ---------= [Eta] =---------
    logger(f"[Eta] Loading started. ")

    eta_random = get_true_value(cfg["task"]["eta"]["random"])

    logger(f"    eta_random: {eta_random}")

    eta_low = get_true_value(cfg["task"]["eta"]["eta_low"])
    eta_high = get_true_value(cfg["task"]["eta"]["eta_high"])

    logger(f"    eta_low: {eta_low}")
    logger(f"    eta_high: {eta_high}")

    # A fixed eta is only read when eta is not sampled.
    default_eta = None
    if not eta_random:
        default_eta = get_true_value(cfg["task"]["eta"]["default_eta"])

    logger(
        f"[Eta] Loading finished. "
        "\n"
    )

    cfg_dict["eta"] = {
        "eta_low": eta_low,
        "eta_high": eta_high,

        "default_eta": default_eta
    }

    # ---------= [Sample] =---------
    logger(f"[Sample] Loading started. ")

    prompt_2 = get_true_value(cfg["task"]["sample"]["prompt_2"])
    negative_prompt = get_true_value(cfg["task"]["sample"]["negative_prompt"])
    negative_prompt_2 = get_true_value(cfg["task"]["sample"]["negative_prompt_2"])

    logger(f"    prompt_2: {prompt_2}")
    logger(f"    negative_prompt: {negative_prompt}")
    logger(f"    negative_prompt_2: {negative_prompt_2}")

    height = get_true_value(cfg["task"]["sample"]["height"])
    width = get_true_value(cfg["task"]["sample"]["width"])
    down_sampling_ratio = get_true_value(cfg["task"]["sample"]["down_sampling_ratio"])

    logger(f"    height: {height}")
    logger(f"    width: {width}")
    logger(f"    down_sampling_ratio: {down_sampling_ratio}")

    num_inference_step = get_true_value(cfg["task"]["sample"]["num_inference_step"])
    guidance_scale = get_true_value(cfg["task"]["sample"]["guidance_scale"])

    logger(f"    num_inference_step: {num_inference_step}")
    logger(f"    guidance_scale: {guidance_scale}")

    logger(
        f"[Sample] Loading finished. "
        "\n"
    )

    cfg_dict["sample"] = {
        "prompt_2": prompt_2,
        "negative_prompt": negative_prompt,
        "negative_prompt_2": negative_prompt_2,

        "height": height,
        "width": width,
        "down_sampling_ratio": down_sampling_ratio,

        "num_inference_step": num_inference_step,

        "guidance_scale": guidance_scale
    }

    # ---------= [Task] =---------
    logger(f"[Task] Loading started. ")

    num_sample_per_prompt = get_true_value(cfg["task"]["task"]["num_sample_per_prompt"])

    logger(f"    num_sample_per_prompt: {num_sample_per_prompt}")

    logger(
        f"[Task] Loading finished. "
        "\n"
    )

    cfg_dict["task"] = {
        "num_sample_per_prompt": num_sample_per_prompt
    }

    # ---------= [Promptist] =---------
    logger(f"[Promptist] Loading started. ")

    enable_promptist = get_true_value(cfg["task"]["promptist"]["enable"])
    cfg_yaml_path_promptist = get_true_value(cfg["task"]["promptist"]["cfg_yaml_path"])

    logger(f"    enable_promptist: {enable_promptist}")
    logger(f"    cfg_yaml_path_promptist: {cfg_yaml_path_promptist}")

    logger(
        f"[Promptist] Loading finished. "
        "\n"
    )

    cfg_dict["promptist"] = {
        "enable": enable_promptist,
        "cfg_yaml_path": cfg_yaml_path_promptist
    }

    # ---------= [Golden Noise] =---------
    logger(f"[Golden Noise] Loading started. ")

    enable_golden_noise = get_true_value(cfg["task"]["golden_noise"]["enable"])
    cfg_yaml_path_golden_noise = get_true_value(cfg["task"]["golden_noise"]["cfg_yaml_path"])

    logger(f"    enable_golden_noise: {enable_golden_noise}")
    logger(f"    cfg_yaml_path_golden_noise: {cfg_yaml_path_golden_noise}")

    logger(
        f"[Golden Noise] Loading finished. "
        "\n"
    )

    cfg_dict["golden_noise"] = {
        "enable": enable_golden_noise,
        "cfg_yaml_path": cfg_yaml_path_golden_noise
    }

    # ---------= [Save] =---------
    logger(f"[Save] Loading started. ")

    save_root_path = get_true_value(cfg["task"]["save"]["save_root_path"])

    logger(f"    save_root_path: {save_root_path}")

    logger(
        f"[Save] Loading finished. "
        "\n"
    )

    cfg_dict["save"] = {
        "save_root_path": save_root_path
    }

    # ---------= [Reward Model] =---------
    logger(f"[Reward Model] Loading started. ")

    reward_model_type = get_true_value(cfg["task"]["reward_model"]["reward_model_type"])
    cal_dynamics_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_dynamics_batch_size"])
    cal_intermediate_reward_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_intermediate_reward_batch_size"])
    cal_final_reward_batch_size = get_true_value(cfg["task"]["reward_model"]["cal_final_reward_batch_size"])
    cal_intermediate_reward_policy = get_true_value(cfg["task"]["reward_model"]["cal_intermediate_reward_policy"])
    reward_shaping_policy = get_true_value(cfg["task"]["reward_model"]["reward_shaping_policy"])

    logger(f"    reward_model_type: {reward_model_type}")
    logger(f"    cal_dynamics_batch_size: {cal_dynamics_batch_size}")
    logger(f"    cal_intermediate_reward_batch_size: {cal_intermediate_reward_batch_size}")
    logger(f"    cal_final_reward_batch_size: {cal_final_reward_batch_size}")
    logger(f"    cal_intermediate_reward_policy: {cal_intermediate_reward_policy}")
    logger(f"    reward_shaping_policy: {reward_shaping_policy}")

    cfg_dict["reward_model"] = {
        "reward_model_type": reward_model_type,
        "cal_dynamics_batch_size": cal_dynamics_batch_size,
        "cal_intermediate_reward_batch_size": cal_intermediate_reward_batch_size,
        "cal_final_reward_batch_size": cal_final_reward_batch_size,
        "cal_intermediate_reward_policy": cal_intermediate_reward_policy,
        "reward_shaping_policy": reward_shaping_policy
    }

    # Policy-specific extras; only the one matching the policy is loaded.
    num_look_ahead_step = None
    gamma = None

    if cal_intermediate_reward_policy == "look_ahead":
        num_look_ahead_step = get_true_value(cfg["task"]["reward_model"]["num_look_ahead_step"])

        logger(f"    num_look_ahead_step: {num_look_ahead_step}")

        cfg_dict["reward_model"]["num_look_ahead_step"] = num_look_ahead_step
    elif cal_intermediate_reward_policy == "discount":
        gamma = get_true_value(cfg["task"]["reward_model"]["gamma"])

        logger(f"    gamma: {gamma}")

        cfg_dict["reward_model"]["gamma"] = gamma

    logger(
        f"[Reward Model] Loading finished. "
        "\n"
    )

    # ---------= [LRU Cache] =---------
    logger(f"[LRU Cache] Loading started. ")

    num_gpu_resident_lim = get_true_value(cfg["task"]["lru_cache"]["num_gpu_resident_lim"])

    logger(f"    num_gpu_resident_lim: {num_gpu_resident_lim}")

    logger(
        f"[LRU Cache] Loading finished. "
        "\n"
    )

    cfg_dict["lru_cache"] = {
        "num_gpu_resident_lim": num_gpu_resident_lim
    }

    # ---------= [Beam Search] =---------
    logger(f"[Beam Search] Loading started. ")

    # Beam search sizes
    num_beam = get_true_value(cfg["task"]["beam_search"]["beam_search"]["num_beam"])
    num_candidate_per_beam = get_true_value(cfg["task"]["beam_search"]["beam_search"]["num_candidate_per_beam"])

    logger(f"    num_beam: {num_beam}")
    logger(f"    num_candidate_per_beam: {num_candidate_per_beam}")

    # MDP modeling
    mdp_modeling = get_true_value(cfg["task"]["beam_search"]["mdp_modeling"])

    logger(f"    mdp_modeling: {mdp_modeling}")

    # Expansion policy
    expansion_action_sampling_policy = get_true_value(cfg["task"]["beam_search"]["expansion"]["expansion_action_sampling_policy"])

    logger(f"    expansion_action_sampling_policy: {expansion_action_sampling_policy}")

    # NFE limits
    nfe_cal_dynamics_lim = get_true_value(cfg["task"]["beam_search"]["nfe_limit"]["nfe_cal_dynamics_lim"])
    nfe_cal_intermediate_reward_lim = get_true_value(cfg["task"]["beam_search"]["nfe_limit"]["nfe_cal_intermediate_reward_lim"])
    nfe_cal_final_reward_lim = get_true_value(cfg["task"]["beam_search"]["nfe_limit"]["nfe_cal_final_reward_lim"])

    logger(f"    nfe_cal_dynamics_lim: {nfe_cal_dynamics_lim}")
    logger(f"    nfe_cal_intermediate_reward_lim: {nfe_cal_intermediate_reward_lim}")
    logger(f"    nfe_cal_final_reward_lim: {nfe_cal_final_reward_lim}")

    logger(
        f"[Beam Search] Loading finished. "
        "\n"
    )

    cfg_dict["beam_search"] = {
        "beam_search": {
            "num_beam": num_beam,
            "num_candidate_per_beam": num_candidate_per_beam
        },

        "mdp_modeling": mdp_modeling,

        "expansion": {
            "expansion_action_sampling_policy": expansion_action_sampling_policy
        },

        "nfe_limit": {
            "nfe_cal_dynamics_lim": nfe_cal_dynamics_lim,
            "nfe_cal_intermediate_reward_lim": nfe_cal_intermediate_reward_lim,
            "nfe_cal_final_reward_lim": nfe_cal_final_reward_lim
        }
    }

    # ---------= [Display] =---------
    logger(f"[Display] Loading started. ")

    display_trajectory = get_true_value(cfg["task"]["display"]["display_trajectory"])
    display_selected_node_depth = get_true_value(cfg["task"]["display"]["display_selected_node_depth"])
    display_cal_state_value = get_true_value(cfg["task"]["display"]["display_cal_state_value"])
    display_reward_sum_to_leaf = get_true_value(cfg["task"]["display"]["display_reward_sum_to_leaf"])
    display_beta_mode_update = get_true_value(cfg["task"]["display"]["display_beta_mode_update"])

    logger(f"    display_trajectory: {display_trajectory}")
    logger(f"    display_selected_node_depth: {display_selected_node_depth}")
    logger(f"    display_cal_state_value: {display_cal_state_value}")
    logger(f"    display_reward_sum_to_leaf: {display_reward_sum_to_leaf}")
    logger(f"    display_beta_mode_update: {display_beta_mode_update}")

    logger(
        f"[Display] Loading finished. "
        "\n"
    )

    cfg_dict["display"] = {
        "display_trajectory": display_trajectory,
        "display_selected_node_depth": display_selected_node_depth,
        "display_cal_state_value": display_cal_state_value,
        "display_reward_sum_to_leaf": display_reward_sum_to_leaf,
        "display_beta_mode_update": display_beta_mode_update
    }

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Check Input] =---------
    # Unsupported settings abort the run (instead of opening a debugger), in
    # line with the other "Only support ..." checks in this function.
    if reward_shaping_policy != "latent_reward":
        logger(
            f"Only support `reward_shaping_policy = 'latent_reward'` in beam search. ",

            log_type = "error"
        )

        raise NotImplementedError(
            f"Only support `reward_shaping_policy = 'latent_reward'` in beam search. "
        )

    if is_eps_action and expansion_action_sampling_policy != "uniform":
        logger(
            f"Only support `expansion_action_sampling_policy = 'uniform'` in beam search. ",

            log_type = "error"
        )

        raise NotImplementedError(
            f"Only support `expansion_action_sampling_policy = 'uniform'` in beam search. "
        )

    # The eps action space is built from `eps_seed_st` / `eps_seed_ed`, which
    # are only loaded when `eps.random` is true. Fail early and explicitly,
    # before any model is loaded, rather than with an unbound-variable error.
    if is_eps_action and not eps_random:
        raise ValueError(
            "An eps-action task requires `eps.random = True`, because the eps "
            "action space needs `eps.seed_st` and `eps.seed_ed`. "
        )

    # ---------= [Prepare Pipeline] =---------
    (
        pipeline_category_name,
        pipeline_type_name
    ) = get_pipeline_category_and_type(
        pipeline_path = pipeline_path
    )

    pipeline = load_pipeline(
        pipeline_type = pipeline_type,
        pipeline_path = pipeline_path,
        torch_dtype = pipeline_torch_dtype,
        variant = pipeline_variant
    )

    pipeline.scheduler = load_scheduler(
        pipeline = pipeline,
        scheduler_type = scheduler_type
    )

    inference_step_minus_one = get_inference_step_minus_one(scheduler_type)

    if pipeline_category_name == "sd_family":
        from my_diffusers.pipeline_stable_diffusion import register_pipeline_stable_diffusion

        # register custom methods: `forward()`, `prepare_everything()`, `step()`
        register_pipeline_stable_diffusion(pipeline)
    elif pipeline_category_name == "sdxl_family":
        from my_diffusers.pipeline_stable_diffusion_xl import register_pipeline_stable_diffusion_xl

        # register custom methods: `forward()`, `prepare_everything()`, `step()`
        register_pipeline_stable_diffusion_xl(pipeline)
    elif pipeline_category_name == "pixart_alpha_family":
        from my_diffusers.pipeline_pixart_alpha import register_pipeline_pixart_alpha

        # register custom methods: `forward()`, `prepare_everything()`, `get_noise_pred()`, `step()`
        register_pipeline_pixart_alpha(pipeline)
    else:
        raise NotImplementedError(
            f"Unsupported `pipeline_category_name`, got `{pipeline_category_name}`. "
        )

    register_scheduling_ddim(
        scheduler = pipeline.scheduler
    )

    # move to GPU after registration
    pipeline = pipeline.to(device)

    # save VRAM by offloading the model to CPU
    pipeline.enable_model_cpu_offload()

    logger(f"    pipeline: {type(pipeline)}")

    # Merge the derived fields into the settings recorded at load time
    # (replacing the dict here would discard them).
    cfg_dict["pipeline"].update({
        "pipeline_category_name": pipeline_category_name,
        "pipeline_type_name": pipeline_type_name,
        "inference_step_minus_one": inference_step_minus_one
    })

    # ---------= [Prompt List] =---------
    prompt_manager = get_prompt_manager(
        prompt_manager_dict = prompt_manager_dict
    )

    prompt_manager.load_prompt_list()
    prompt_manager.prepare_everything(
        shuffle = True
    )

    prompt_list = list(prompt_manager.prompt_list)
    folder_name_list = prompt_manager.folder_name_list

    if num_prompt is not None:
        prompt_list = prompt_list[: num_prompt]
        folder_name_list = folder_name_list[: num_prompt]

    # Always use the number of prompts actually available, so that a
    # configured `num_prompt` larger than the prompt list cannot inflate
    # `num_sample` beyond what the prompts can cover.
    num_prompt = len(prompt_list)

    logger(f"prompt_list: {prompt_list}")

    # ---------= [Promptist] =---------
    optimized_prompt_list = None

    if enable_promptist:
        from ..util.promptist_util import (
            load_promptist_model,
            optimize_text_prompt
        )

        (
            promptist_model,
            promptist_tokenizer
        ) = load_promptist_model(
            cfg_yaml_path = cfg_yaml_path_promptist,

            device = device
        )

        optimized_prompt_list = [
            optimize_text_prompt(
                prompt = prompt,

                promptist_model = promptist_model,
                promptist_tokenizer = promptist_tokenizer
            )
            for prompt in prompt_list
        ]

        # NB:
        # `folder_name_list` is still generated with the original prompts

        # ---------= [Clean Up] =---------
        del promptist_model, promptist_tokenizer
        torch.cuda.empty_cache()
        gc.collect()

    # ---------= [Prepare Task] =---------
    num_sample = num_prompt * num_sample_per_prompt

    # Extend (not replace) so that `num_sample_per_prompt` stays recorded.
    cfg_dict["task"].update({
        "num_prompt": num_prompt,
        "num_sample": num_sample
    })

    # ---------= [Prepare Init Latent] =---------
    latent_num_channel = 4
    latent_height = height // down_sampling_ratio
    latent_width = width // down_sampling_ratio

    latent_shape = (latent_num_channel, latent_height, latent_width)

    if not init_latent_random:
        init_latent_seed_list = prepare_seed_list(
            seed_list = init_latent_seed_list,
            target_length = num_sample_per_prompt,

            auto_inrement = init_latent_seed_auto_increment
        )

        cfg_dict["init_latent"]["init_latent_seed_list"] = init_latent_seed_list

        # Samples are laid out prompt-major: sample `i` belongs to prompt
        # `i // num_sample_per_prompt` and uses seed `i % num_sample_per_prompt`,
        # so every prompt sees the same sequence of seeds.
        init_latent_list = [
            get_latent(
                shape = latent_shape,

                seed = init_latent_seed_list[sample_idx % num_sample_per_prompt],

                device = "cpu",

                dtype = pipeline_torch_dtype
            )
            for sample_idx in range(num_sample)
        ]
    else:
        raise NotImplementedError(
            f"Only support `init_latent_random = False`. "
        )

    # init_latent_list.shape = (num_sample, latent_num_channel, latent_height, latent_width)
    init_latent_list = torch.stack(init_latent_list)

    # ---------= [Golden Noise] =---------
    if enable_golden_noise:
        from ..util.golden_noise.util import (
            load_golden_noise_model,
            optimize_initial_noise
        )

        if pipeline_category_name == "sdxl_family":
            npnet = load_golden_noise_model(
                cfg_yaml_path = cfg_yaml_path_golden_noise,

                sd_type = "SDXL",

                device = device
            )
        else:
            raise NotImplementedError(
                f"Only support SDXL pipeline. "
            )

        # One prompt per sample, matching the prompt-major latent layout
        # above: the prompt index is `sample_idx // num_sample_per_prompt`.
        source_prompt_list = optimized_prompt_list if enable_promptist \
            else prompt_list

        duplicated_prompt_list = [
            source_prompt_list[sample_idx // num_sample_per_prompt]
            for sample_idx in range(num_sample)
        ]

        optimized_init_latent_list = optimize_initial_noise(
            initial_noise_list = init_latent_list,

            npnet = npnet,

            prompt_list = duplicated_prompt_list,

            pipeline = pipeline
        )

        # Use the optimized noise from here on, kept on CPU like the
        # original latents.
        # NOTE(review): assumes `optimize_initial_noise` returns a tensor --
        # confirm against its implementation.
        init_latent_list = optimized_init_latent_list.cpu()

        # ---------= [Clean Up] =---------
        del optimized_init_latent_list
        del duplicated_prompt_list
        del npnet
        torch.cuda.empty_cache()
        gc.collect()

    # ---------= [Prepare Eps] =---------
    # Fixed per-step noise is only prepared when eps is neither the action
    # nor sampled randomly.
    eps_list = None
    if (not is_eps_action) and (not eps_random):
        eps_seed_list = prepare_seed_list(
            seed_list = eps_seed_list,
            target_length = num_inference_step,

            auto_inrement = eps_seed_auto_increment
        )

        cfg_dict["eps"]["eps_seed_list"] = eps_seed_list

        eps_list = [
            get_latent(
                shape = latent_shape,

                seed = eps_seed,

                device = device,

                dtype = pipeline_torch_dtype
            )
            for eps_seed in eps_seed_list
        ]

    # ---------= [Prepare Save Sample] =---------
    save_root_path = Path(save_root_path)
    pipeline_root_path = save_root_path / pipeline_type_name

    dataset_type_str = prompt_manager_dict["prompt_manager_type"]
    dataset_root_path = pipeline_root_path / dataset_type_str

    # Some reward models append one of their own settings to the folder name.
    reward_model_type_str = reward_model_type
    if reward_model_type == "color_channel_reward":
        reward_model_cfg_yaml_path = Path("./config/reward_model/color_channel_reward.yaml")
        reward_model_cfg_dict = load_yaml(reward_model_cfg_yaml_path)
        target_channel_idx = reward_model_cfg_dict["color_channel_reward"]["target_channel_idx"]
        reward_model_type_str = f"{reward_model_type_str}-{target_channel_idx}"
    elif reward_model_type == "laplacian_var_reward":
        reward_model_cfg_yaml_path = Path("./config/reward_model/laplacian_var_reward.yaml")
        reward_model_cfg_dict = load_yaml(reward_model_cfg_yaml_path)
        tsfm_to_gray = reward_model_cfg_dict["laplacian_var_reward"]["tsfm_to_gray"]
        if tsfm_to_gray:
            reward_model_type_str = f"{reward_model_type_str}-gray"
        else:
            reward_model_type_str = f"{reward_model_type_str}-rgb"
    elif reward_model_type == "compressibility_hps_v2":
        reward_model_cfg_yaml_path = Path("./config/reward_model/compressibility_hps_v2.yaml")
        reward_model_cfg_dict = load_yaml(reward_model_cfg_yaml_path)
        lam = reward_model_cfg_dict["compressibility_hps_v2"]["lam"]
        reward_model_type_str = f"{reward_model_type_str}-{lam}"

    # <save_root>/<pipeline>/<dataset>/<reward model>/<shaping>/<mdp>/<steps>/<exp>
    reward_model_root_path = dataset_root_path / reward_model_type_str
    reward_shaping_policy_root_path = reward_model_root_path / reward_shaping_policy

    mdp_modeling_root_path = reward_shaping_policy_root_path / mdp_modeling

    inference_step_root_path = mdp_modeling_root_path / f"{num_inference_step}"

    exp_root_path = inference_step_root_path / exp_time_str

    cfg_dict["save"]["exp_root_path"] = exp_root_path

    # ---------= [Prepare MDP] =---------
    action_shape = (1, )

    state_space = StateSpace(
        shape = latent_shape,

        dtype = pipeline_torch_dtype,
        device = device,

        ver = "torch"
    )

    action_space_eta = None
    if (not is_eps_action) or eta_random:
        action_space_eta = ActionSpaceEta(
            eta_low = eta_low,
            eta_high = eta_high,

            shape = action_shape,

            dtype = pipeline_torch_dtype,
            device = device,

            ver = "torch"
        )

    # `eps_seed_st` / `eps_seed_ed` are guaranteed to be bound here: the
    # input check above rejects an eps-action task without `eps.random`.
    action_space_eps = None
    if is_eps_action or eps_random:
        action_space_eps = ActionSpaceEps(
            eps_seed_low = eps_seed_st,
            eps_seed_high = eps_seed_ed,

            shape = action_shape,

            dtype = pipeline_torch_dtype,
            device = device,

            ver = "torch"
        )

    reward_shape = (1, )

    diffusion_mdp = DiffusionMDP(
        # ---------= [Beam Search] =---------
        is_eps_action = is_eps_action,

        # ---------= [Eta] =---------
        eta_random = eta_random,
        default_eta = default_eta,

        state_space = state_space,
        action_space_eta = action_space_eta,
        action_space_eps = action_space_eps,
        time_horizon = num_inference_step,

        # ---------= [Reward] =---------
        reward_shape = reward_shape,

        # ---------= [Pipeline] =---------
        pipeline = pipeline,
        eps_list = eps_list,

        cal_intermediate_reward_policy = cal_intermediate_reward_policy
    )

    cfg_dict["eta"]["action_shape"] = action_shape
    cfg_dict["reward_model"]["reward_shape"] = reward_shape

    # ---------= [Prepare Everything] =---------
    diffusion_mdp.prepare_everything(
        prompt_list = prompt_list,
        optimized_prompt_list = optimized_prompt_list,

        prompt_2 = prompt_2,
        negative_prompt = negative_prompt,
        negative_prompt_2 = negative_prompt_2,

        height = height, width = width,
        guidance_scale = guidance_scale,

        init_latent_list = init_latent_list,

        num_sample_per_prompt = num_sample_per_prompt,

        inference_step_minus_one = inference_step_minus_one
    )

    # From here on, use the latents as held by the MDP after preparation.
    init_latent_list = diffusion_mdp.init_latent_list

    # ---------= [Prepare Reward Model] =---------
    prompt_emb_list = diffusion_mdp.prompt_emb_list
    param_dict = diffusion_mdp.param_dict

    reward_model = get_reward_model(
        reward_model_type = reward_model_type,

        # ---------= [Pipeline] =---------
        pipeline = pipeline,
        num_inference_step = num_inference_step,

        # ---------= [Param] =---------
        prompt_emb_list = prompt_emb_list,
        param_dict = param_dict,
        num_sample_per_prompt = num_sample_per_prompt,

        # ---------= [Reward] =---------
        reward_shape = reward_shape,
        reward_dtype = pipeline_torch_dtype,
        offload_to_cpu = True,

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size = cal_dynamics_batch_size,
        cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size,
        cal_final_reward_batch_size = cal_final_reward_batch_size,

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy = reward_shaping_policy,
        cal_intermediate_reward_policy = cal_intermediate_reward_policy,

        device = device,

        vae_decode_batch_size = vae_decode_batch_size
    )

    # Only the keyword matching the intermediate-reward policy is forwarded.
    cal_intermediate_reward_arg_dict = {}

    if cal_intermediate_reward_policy == "look_ahead":
        cal_intermediate_reward_arg_dict["num_look_ahead_step"] = num_look_ahead_step
    elif cal_intermediate_reward_policy == "discount":
        cal_intermediate_reward_arg_dict["gamma"] = gamma

    diffusion_mdp.set_reward_model(
        reward_model = reward_model,

        **cal_intermediate_reward_arg_dict
    )

    if reward_model_type == "color_channel_reward":
        target_channel_idx = reward_model.target_channel_idx
        cfg_dict["reward_model"]["target_channel_idx"] = target_channel_idx

    # ---------= [Prepare LRU Cache] =---------
    lru_cache = LRUCache(
        num_gpu_resident_lim = num_gpu_resident_lim,

        device = device
    )

    # ---------= [Diffusion BS] =---------
    # Same policy-specific keyword selection as for the reward model above.
    arg_dict = {}

    if cal_intermediate_reward_policy == "look_ahead":
        arg_dict["num_look_ahead_step"] = num_look_ahead_step
    elif cal_intermediate_reward_policy == "discount":
        arg_dict["gamma"] = gamma

    diffusion_bs = DiffusionOTBS(
        is_eps_action = is_eps_action,

        # ---------= [Beam Search] =---------
        num_beam = num_beam,
        num_candidate_per_beam = num_candidate_per_beam,

        mdp = diffusion_mdp,
        init_state_list = init_latent_list,

        # ---------= [Mode] =---------
        mdp_modeling = mdp_modeling,

        # ---------= [Expansion Policy] =---------
        expansion_action_sampling_policy = expansion_action_sampling_policy,

        # ---------= [NFE Limit] =---------
        nfe_cal_dynamics_lim = nfe_cal_dynamics_lim,
        nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim,
        nfe_cal_final_reward_lim = nfe_cal_final_reward_lim,

        # ---------= [Optimal Control Beta] =---------
        lru_cache = lru_cache,

        # ---------= [Dtype] =---------
        dtype = pipeline_torch_dtype,

        # ---------= [Save Root Path] =---------
        expansion_policy_root_path = exp_root_path,
        folder_name_list = folder_name_list,
        cfg_dict = cfg_dict,

        **arg_dict
    )

    # ---------= [Run Diffusion BS] =---------
    diffusion_bs.run(
        display_result = True,

        display_trajectory = display_trajectory,
        display_state = False,
        display_action = True,
        display_reward = True,

        display_cal_state_value = display_cal_state_value,
        display_selected_node_depth = display_selected_node_depth,
        display_reward_sum_to_leaf = display_reward_sum_to_leaf
    )

    # `run_optimal_control_bs_implement()` done


def run_optimal_control_bs(
    cfg: DictConfig
) -> None:
    """Task entry point: run the beam-search optimal-control experiment.

    Forwards `cfg` unchanged to `run_optimal_control_bs_implement()` and
    returns nothing.

    Args:
        cfg: Experiment configuration passed through to the implementation.
    """
    run_optimal_control_bs_implement(cfg)
