from util.logger import logger

from typing import Optional, List, Dict

from omegaconf import DictConfig

import concurrent.futures as cf

from pathlib import Path

import random

import time

import gc

from tqdm.auto import tqdm

import numpy as np

import torch

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.numpy_util import tsfm_to_1d_array
from util.pkl_util import load_pkl_cached
from util.image_util import save_pil_as_png
from util.yaml_util import (
    save_yaml, 
    convert_numpy_type_to_native_type
)
from util.torch_util import get_latent
from util.pipeline_util import (
    load_pipeline, load_scheduler, 
    get_inference_step_minus_one, 
    get_pipeline_category_and_type, 
    img_latent_to_pil, 
    get_folder_name
)

from task.util.seed_list_util import prepare_seed_list

from prompt_manager.util import get_prompt_manager

from my_diffusers.scheduling_ddim import register_scheduling_ddim


@torch.no_grad()
def run_sample_scheduled_implement(
    cfg: DictConfig
):
    """Run batched sampling where init latent, per-step eps and per-step eta are scheduled per sample.

    Workflow:
        1. Read every setting from ``cfg`` (plus the global variables
           ``exp_name``, ``start_time``, ``device``, ``seed``,
           ``concurrent_max_worker`` and ``vae_decode_batch_size``) and log it.
        2. Load the pipeline and scheduler, register the custom
           ``prepare_everything()`` / ``get_noise_pred()`` / ``step()`` methods
           for the pipeline family, and call ``register_scheduling_ddim()`` on
           the scheduler.
        3. Load the prompts; every prompt is sampled ``num_sample_per_prompt``
           times.
        4. For every batch, pick an init-latent seed per sample and a list of
           ``num_inference_step`` eps seeds and eta values per sample -- either
           drawn with the global ``random`` module (``random: true`` in the
           respective cfg section) or taken from fixed lists shared by all
           samples -- then denoise step by step, regenerating latents / eps
           from their seeds with ``get_latent()``.
        5. Decode the latents, save each image to
           ``<exp_root>/<folder_name>/png/<k>.png`` and the config that
           produced it to ``<exp_root>/<folder_name>/cfg/<k>.yaml``
           (``k`` = index of the sample among the samples of its prompt).
        6. Save the sampling wall-clock time to ``<exp_root>/time_cost.yaml``.

    Args:
        cfg: Config with ``cfg["pipeline"]`` and ``cfg["task"]`` sections
            holding the keys read below.

    Raises:
        NotImplementedError: If the pipeline category is not supported.
    """
    # ---------= [Basic Global Variables] =---------
    exp_name = get_global_variable("exp_name")
    start_time = get_global_variable("start_time")
    device = get_global_variable("device")
    seed = get_global_variable("seed")
    exp_time_str = f"{exp_name}_{start_time}"

    concurrent_max_worker = get_global_variable("concurrent_max_worker")

    vae_decode_batch_size = get_global_variable("vae_decode_batch_size")

    # ---------= [Pipeline] =---------
    logger(f"[Pipeline] Loading started. ")

    pipeline_type = get_true_value(cfg["pipeline"]["pipeline_type"])
    pipeline_path = get_true_value(cfg["pipeline"]["pipeline_path"])
    pipeline_torch_dtype = get_true_value(cfg["pipeline"]["torch_dtype"])
    pipeline_variant = get_true_value(cfg["pipeline"]["variant"])
    
    logger(f"    pipeline_type: {pipeline_type}")
    logger(f"    pipeline_path: {pipeline_path}")
    logger(f"    pipeline_torch_dtype: {pipeline_torch_dtype}")
    logger(f"    pipeline_variant: {pipeline_variant}")

    scheduler_type = get_true_value(cfg["pipeline"]["scheduler"])

    logger(f"    scheduler_type: {scheduler_type}")

    logger(
        f"[Pipeline] Loading finished. "
        "\n"
    )

    # ---------= [Prompt List] =---------
    logger(f"[Prompt List] Loading started. ")

    num_prompt = get_true_value(cfg["task"]["prompt_list"]["num_prompt"])
    prompt_manager_dict = get_true_value(cfg["task"]["prompt_list"]["prompt_manager_dict"])

    logger(f"    num_prompt: {num_prompt}")
    logger(f"    prompt_manager_dict: {prompt_manager_dict}")

    logger(
        f"[Prompt List] Loading finished. "
        "\n"
    )

    # ---------= [Init Latent] =---------
    logger(f"[Init Latent] Loading started. ")

    init_latent_root_path = get_true_value(cfg["task"]["init_latent"]["root_path"])

    logger(f"    init_latent_root_path: {init_latent_root_path}")

    init_latent_random = get_true_value(cfg["task"]["init_latent"]["random"])

    logger(f"    init_latent_random: {init_latent_random}")

    if init_latent_random:
        init_latent_seed_st = get_true_value(cfg["task"]["init_latent"]["seed_st"])
        init_latent_seed_ed = get_true_value(cfg["task"]["init_latent"]["seed_ed"])
        
        logger(f"    init_latent_seed_st: {init_latent_seed_st}")
        logger(f"    init_latent_seed_ed: {init_latent_seed_ed}")
    else:
        init_latent_seed_list = get_true_value(cfg["task"]["init_latent"]["seed_list"])
        init_latent_seed_auto_increment = get_true_value(cfg["task"]["init_latent"]["seed_auto_increment"])
        
        logger(f"    init_latent_seed_list: {init_latent_seed_list}")
        logger(f"    init_latent_seed_auto_increment: {init_latent_seed_auto_increment}")

    logger(
        f"[Init Latent] Loading finished. "
        "\n"
    )

    # ---------= [Eps] =---------
    logger(f"[Eps] Loading started. ")

    eps_root_path = get_true_value(cfg["task"]["eps"]["root_path"])

    logger(f"    eps_root_path: {eps_root_path}")

    eps_random = get_true_value(cfg["task"]["eps"]["random"])

    logger(f"    eps_random: {eps_random}")

    if eps_random:
        eps_seed_st = get_true_value(cfg["task"]["eps"]["seed_st"])
        eps_seed_ed = get_true_value(cfg["task"]["eps"]["seed_ed"])
        
        logger(f"    eps_seed_st: {eps_seed_st}")
        logger(f"    eps_seed_ed: {eps_seed_ed}")
    else:
        eps_seed_list = get_true_value(cfg["task"]["eps"]["seed_list"])
        eps_seed_auto_increment = get_true_value(cfg["task"]["eps"]["seed_auto_increment"])
        
        logger(f"    eps_seed_list: {eps_seed_list}")
        logger(f"    eps_seed_auto_increment: {eps_seed_auto_increment}")

    logger(
        f"[Eps] Loading finished. "
        "\n"
    )

    # ---------= [Eta] =---------
    logger(f"[Eta] Loading started. ")

    eta_random = get_true_value(cfg["task"]["eta"]["random"])

    logger(f"    eta_random: {eta_random}")

    if eta_random:
        eta_st = get_true_value(cfg["task"]["eta"]["eta_st"])
        eta_ed = get_true_value(cfg["task"]["eta"]["eta_ed"])
        
        logger(f"    eta_st: {eta_st}")
        logger(f"    eta_ed: {eta_ed}")
    else:
        eta_list = get_true_value(cfg["task"]["eta"]["eta_list"])

        logger(f"    eta_list: {eta_list}")

    logger(
        f"[Eta] Loading finished. "
        "\n"
    )

    # ---------= [Sample] =---------
    logger(f"[Sample] Loading started. ")

    prompt_2 = get_true_value(cfg["task"]["sample"]["prompt_2"])
    negative_prompt = get_true_value(cfg["task"]["sample"]["negative_prompt"])
    negative_prompt_2 = get_true_value(cfg["task"]["sample"]["negative_prompt_2"])

    logger(f"    prompt_2: {prompt_2}")
    logger(f"    negative_prompt: {negative_prompt}")
    logger(f"    negative_prompt_2: {negative_prompt_2}")

    height = get_true_value(cfg["task"]["sample"]["height"])
    width = get_true_value(cfg["task"]["sample"]["width"])
    down_sampling_ratio = get_true_value(cfg["task"]["sample"]["down_sampling_ratio"])

    logger(f"    height: {height}")
    logger(f"    width: {width}")
    logger(f"    down_sampling_ratio: {down_sampling_ratio}")

    num_inference_step = get_true_value(cfg["task"]["sample"]["num_inference_step"])
    guidance_scale = get_true_value(cfg["task"]["sample"]["guidance_scale"])

    logger(f"    num_inference_step: {num_inference_step}")
    logger(f"    guidance_scale: {guidance_scale}")

    logger(
        f"[Sample] Loading finished. "
        "\n"
    )

    # ---------= [Task] =---------
    logger(f"[Task] Loading started. ")
    
    num_sample_per_prompt = get_true_value(cfg["task"]["task"]["num_sample_per_prompt"])
    batch_size = get_true_value(cfg["task"]["task"]["batch_size"])

    logger(f"    num_sample_per_prompt: {num_sample_per_prompt}")
    logger(f"    batch_size: {batch_size}")

    logger(
        f"[Task] Loading finished. "
        "\n"
    )

    # ---------= [Save Sample] =---------
    logger(f"[Save Sample] Loading started. ")
    
    save_sample_root_path = get_true_value(cfg["task"]["save_sample"]["save_sample_root_path"])

    logger(f"    save_sample_root_path: {save_sample_root_path}")

    logger(
        f"[Save Sample] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Pipeline] =---------
    (
        pipeline_category_name, 
        pipeline_type_name
    ) = get_pipeline_category_and_type(
        pipeline_path = pipeline_path
    )

    pipeline = load_pipeline(
        pipeline_type = pipeline_type, 
        pipeline_path = pipeline_path, 
        torch_dtype = pipeline_torch_dtype, 
        variant = pipeline_variant
    )

    pipeline.scheduler = load_scheduler(
        pipeline = pipeline, 
        scheduler_type = scheduler_type
    )
    
    inference_step_minus_one = get_inference_step_minus_one(scheduler_type)

    # register custom methods: `forward()`, `prepare_everything()`, `get_noise_pred()`, `step()`
    if pipeline_category_name == "sd_family":
        from my_diffusers.pipeline_stable_diffusion import register_pipeline_stable_diffusion

        register_pipeline_stable_diffusion(pipeline)
    
    elif pipeline_category_name == "sdxl_family":
        from my_diffusers.pipeline_stable_diffusion_xl import register_pipeline_stable_diffusion_xl

        register_pipeline_stable_diffusion_xl(pipeline)

    elif pipeline_category_name == "sd_3_family":
        from my_diffusers.pipeline_stable_diffusion_3 import register_pipeline_stable_diffusion_3

        register_pipeline_stable_diffusion_3(pipeline)

    elif pipeline_category_name == "pixart_alpha_family":
        from my_diffusers.pipeline_pixart_alpha import register_pipeline_pixart_alpha

        register_pipeline_pixart_alpha(pipeline)

    else:
        raise NotImplementedError(
            f"Unsupported `pipeline_category_name`, got `{pipeline_category_name}`. "
        )
    
    register_scheduling_ddim(
        scheduler = pipeline.scheduler
    )
    
    # move to GPU after registration
    # NOTE(review): in diffusers, `enable_model_cpu_offload()` moves the
    # pipeline back to CPU itself, so this full move to `device` first mostly
    # costs a peak of VRAM -- confirm it is still needed by the custom methods.
    pipeline.to(device)

    # save VRAM by offloading the model to CPU
    pipeline.enable_model_cpu_offload()

    logger(f"    pipeline: {type(pipeline)}")

    # ---------= [Prompt List] =---------
    prompt_manager = get_prompt_manager(
        prompt_manager_dict = prompt_manager_dict
    )

    prompt_manager.load_prompt_list()
    prompt_manager.prepare_everything(
        shuffle = True
    )

    prompt_list = prompt_manager.prompt_list
    folder_name_list = prompt_manager.folder_name_list
    
    if num_prompt is not None:
        prompt_list = prompt_list[: num_prompt]
        folder_name_list = folder_name_list[: num_prompt]

    # Always recount: the requested `num_prompt` may exceed the number of
    # available prompts, in which case the slice above is shorter and the
    # number of samples / batches must follow the real count.
    num_prompt = len(prompt_list)
    
    # prompt_tuple_list: (prompt_idx, prompt), each prompt repeated
    # `num_sample_per_prompt` times consecutively, so global sample `k` belongs
    # to prompt `k // num_sample_per_prompt` and is its
    # `k % num_sample_per_prompt`-th sample.
    # len(prompt_tuple_list) = num_prompt * num_sample_per_prompt
    prompt_tuple_list = [
        (prompt_idx, prompt) \
            for prompt_idx, prompt in enumerate(prompt_list) \
                for _ in range(num_sample_per_prompt)
    ]

    # ---------= [Prepare Task] =---------
    num_sample = num_prompt * num_sample_per_prompt

    num_batch = (num_sample + batch_size - 1) // batch_size

    # ---------= [Prepare Init Latent] =---------
    latent_num_channel = 16 if (pipeline_category_name == "sd_3_family") \
        else 4
    latent_height = height // down_sampling_ratio
    latent_width = width // down_sampling_ratio

    latent_shape = (latent_num_channel, latent_height, latent_width)
    latent_shape_str = f"{latent_num_channel}_{latent_height}_{latent_width}"

    # NOTE: init latents and eps are regenerated on the fly from their seeds
    # via `get_latent()`; the root paths are only recorded in the saved cfg.
    init_latent_root_path = Path(init_latent_root_path) / latent_shape_str

    if not init_latent_random:
        # one init-latent seed per sample index within a prompt
        init_latent_seed_list = prepare_seed_list(
            seed_list = init_latent_seed_list, 
            target_length = num_sample_per_prompt, 

            auto_inrement = init_latent_seed_auto_increment
        )

    # ---------= [Prepare Eps] =---------
    eps_root_path = Path(eps_root_path) / latent_shape_str

    if not eps_random:
        # one eps seed per denoising step, shared by every sample
        eps_seed_list = prepare_seed_list(
            seed_list = eps_seed_list, 
            target_length = num_inference_step, 

            auto_inrement = eps_seed_auto_increment
        )

    # ---------= [Prepare Eta] =---------
    if not eta_random:
        # Keep plain Python floats (one per denoising step). They are turned
        # into a CPU tensor per batch and cast to the pipeline dtype inside
        # `implement_batch()`, exactly like randomly drawn etas, so both paths
        # yield the same eta batch shape and a plain list in the saved YAML.
        eta_list = tsfm_to_1d_array(
            array = list(eta_list), 
            target_length = num_inference_step, 

            dtype = pipeline_torch_dtype
        ).tolist()
    
    # ---------= [Prepare Save Sample] =---------
    save_sample_root_path = Path(save_sample_root_path)
    pipeline_root_path = save_sample_root_path / pipeline_type_name

    dataset_type_str = prompt_manager_dict["prompt_manager_type"]
    dataset_root_path = pipeline_root_path / dataset_type_str

    if exp_name in ["ddpm", "deterministic_ddim"]:
        exp_name_str = exp_name
        if exp_name == "ddpm":
            if not eps_random:
                exp_name_str = f"{exp_name_str}_{eps_seed_list[0]}"
            else:
                exp_name_str = f"{exp_name_str}_{eps_seed_st}"

        exp_root_path = dataset_root_path / f"{num_inference_step}" / exp_name_str
    else:
        exp_root_path = dataset_root_path / f"{num_inference_step}" / exp_time_str
    
    # ---------= [Prepare Task Cfg] =---------
    # Per-sample fields (seeds, eta list, prompt, folder name) are filled in
    # right before each sample's cfg is saved.
    cfg_dict = {
        "global_variable": {
            "exp_name": exp_name, 
            "start_time": start_time, 
            "device": device, 
            "seed": seed
        }, 

        "pipeline": {
            "pipeline_type": pipeline_type, 
            "pipeline_path": pipeline_path, 
            "pipeline_torch_dtype": pipeline_torch_dtype, 
            "pipeline_variant": pipeline_variant, 

            "pipeline_category_name": pipeline_category_name, 
            "pipeline_type_name": pipeline_type_name, 

            "scheduler_type": scheduler_type
        }, 

        "init_latent": {
            "init_latent_root_path": init_latent_root_path, 
            "init_latent_seed": None
        }, 

        "eps": {
            "eps_root_path": eps_root_path, 
            "eps_seed_list": []
        }, 

        "eta_list": [], 

        "sample": {
            "prompt_2": prompt_2, 
            "negative_prompt": negative_prompt, 
            "negative_prompt_2": negative_prompt_2, 

            "height": height, 
            "width": width, 

            "num_inference_step": num_inference_step, 
            "inference_step_minus_one": inference_step_minus_one, 

            "guidance_scale": guidance_scale
        }, 

        "task": {
            "num_sample_per_prompt": num_sample_per_prompt, 
            "batch_size": batch_size
        }, 

        "save_sample": {
            "save_sample_root_path": save_sample_root_path, 
            "pipeline_type_name": pipeline_type_name, 
        }
    }

    # ---------= [Wall-clock Time Cost] =---------
    time_st = time.time()

    # ---------= [Do Sampling] =---------
    def implement_batch(
        batch_idx: int, 
        true_batch_size: int, 

        arg_dict: Optional[Dict] = None
    ) -> List[torch.Tensor]:
        """Denoise one batch and return the latents produced by the last `pipeline.step()`.

        Args:
            batch_idx: Batch index; the batch covers global samples
                `[batch_idx * batch_size, batch_idx * batch_size + true_batch_size)`.
            true_batch_size: Number of samples in this batch (the last batch
                may be smaller than `batch_size`).
            arg_dict: Built by the sampling loop below; holds
                `init_latent_seed_list` (list of `true_batch_size` ints),
                `eps_seed_list_list` (int tensor, one row of
                `num_inference_step` seeds per sample) and `eta_list_list`
                (float tensor, one row of `num_inference_step` etas per sample).
        """
        torch_dtype = get_attr("torch", pipeline_torch_dtype)

        # ---------= [Prepare Batch Prompt List] =---------
        sample_idx_st = batch_idx * batch_size
        sample_idx_ed = sample_idx_st + true_batch_size

        batch_prompt_tuple_list = prompt_tuple_list[sample_idx_st: sample_idx_ed]

        logger(
            f"[Batch {batch_idx}] batch_prompt_tuple_list: {batch_prompt_tuple_list}"
        )

        # batch_prompt_list.shape = (true_batch_size, )
        batch_prompt_list = [
            prompt \
                for (_, prompt) in batch_prompt_tuple_list
        ]

        # ---------= [Prepare Batch Latent List] =---------
        # presumably (true_batch_size, latent_num_channel, latent_height, latent_width),
        # assuming `get_latent()` returns one tensor of shape `latent_shape`
        batch_latent_list = torch.stack([
            get_latent(
                shape = latent_shape, 
                seed = init_latent_seed, 
                device = device, 
                dtype = pipeline_torch_dtype
            ) \
                for init_latent_seed in arg_dict["init_latent_seed_list"]
        ])

        eps_seed_list_list = arg_dict["eps_seed_list_list"]
        eta_list_list = arg_dict["eta_list_list"]
        
        # ---------= [Prepare Everything] =---------
        (
            # SD v1.4, SDXL
            #     (no CFG) prompt_emb_list.shape = (true_batch_size, 77, 1024)
            #     (CFG) prompt_emb_list.shape = (2 * true_batch_size, 77, 1024)
            # SD v3.5 medium
            #     (no CFG) prompt_emb_list.shape = (true_batch_size, 333, 4096)
            #     (CFG) prompt_emb_list.shape = (2 * true_batch_size, 333, 4096)
            prompt_emb_list, 

            param_dict, 
            
            batch_latent_list
        ) = pipeline.prepare_everything(
            prompt = batch_prompt_list, 
            prompt_2 = [prompt_2] * true_batch_size if (prompt_2 is not None) \
                else None, 
            negative_prompt = [negative_prompt] * true_batch_size if (negative_prompt is not None) \
                else None, 
            negative_prompt_2 = [negative_prompt_2] * true_batch_size if (negative_prompt_2 is not None) \
                else None, 

            height = height, width = width, 
            guidance_scale = guidance_scale, 
            num_inference_steps = num_inference_step, 

            num_images_per_prompt = 1, 

            latents = batch_latent_list, 

            return_dict = False, 

            inference_step_minus_one = inference_step_minus_one
        )
        
        # ---------= [Do Denoising] =---------
        for timestep_idx in tqdm(
            range(num_inference_step), 
            desc = f"[Denoising]"
        ):
            # ---------= [Prepare Batch Eps List] =---------
            # `.tolist()` so `get_latent()` receives plain int seeds (as for
            # the init latents), not 0-d tensors.
            step_eps_seed_list = eps_seed_list_list[:, timestep_idx].tolist()

            batch_eps_list = torch.stack([
                get_latent(
                    shape = latent_shape, 
                    seed = eps_seed, 
                    device = device, 
                    dtype = pipeline_torch_dtype
                ) \
                    for eps_seed in step_eps_seed_list
            ]).to(device)

            # ---------= [Prepare Batch Eta List] =---------
            # batch_eta_list.shape = (true_batch_size, )
            batch_eta_list = eta_list_list[:, timestep_idx].to(
                device = device, 
                dtype = torch_dtype
            )
            
            # ---------= [Step] =---------
            (
                batch_noise_pred, 
                batch_timestep
            ) = pipeline.get_noise_pred(
                param_dict = param_dict, 
                prompt_emb_list = prompt_emb_list, 

                latent_list = batch_latent_list, 
                
                timestep_list = None, 
                timestep_idx_list = timestep_idx
            )
            
            batch_latent_list = pipeline.step(
                latent_list = batch_latent_list, 

                noise_pred = batch_noise_pred, 
                
                timestep = batch_timestep, 
                prev_timestep = None, 

                eta_list = batch_eta_list, 
                eps = batch_eps_list
            )
            
            # ---------= [Clean Up] =---------
            del batch_eta_list
            del batch_eps_list
            del batch_noise_pred, batch_timestep
            gc.collect()
            torch.cuda.empty_cache()

        # ---------= [Clean Up] =---------
        del param_dict
        gc.collect()
        torch.cuda.empty_cache()

        return batch_latent_list

    for batch_idx in tqdm(
        range(num_batch), 

        desc = "[Sampling]"
    ):
        # ---------= [Sample Batch] =---------
        sample_idx_st = batch_idx * batch_size

        # every batch is full except possibly the last one
        true_batch_size = min(batch_size, num_sample - sample_idx_st)

        # NOTE: the order of the `random` draws below (init seeds, then eps
        # seeds, then etas; sample-major) is kept stable for reproducibility.

        # ---------= [Prepare Init Latent Seed List] =---------
        if init_latent_random:
            batch_init_latent_seed_list = [
                random.randint(init_latent_seed_st, init_latent_seed_ed) \
                    for _ in range(true_batch_size)
            ]
        else:
            batch_init_latent_seed_list = [
                init_latent_seed_list[(sample_idx_st + sample_idx) % num_sample_per_prompt] \
                    for sample_idx in range(true_batch_size)
            ]

        # ---------= [Prepare Eps Seed List List] =---------
        if eps_random:
            batch_eps_seed_list_list = [
                [
                    random.randint(eps_seed_st, eps_seed_ed) \
                        for _ in range(num_inference_step)
                ] \
                    for _ in range(true_batch_size)
            ]
        else:
            batch_eps_seed_list_list = [
                list(eps_seed_list) \
                    for _ in range(true_batch_size)
            ]

        # ---------= [Prepare Eta List List] =---------
        if eta_random:
            batch_eta_list_list = [
                [
                    random.uniform(eta_st, eta_ed) \
                        for _ in range(num_inference_step)
                ] \
                    for _ in range(true_batch_size)
            ]
        else:
            batch_eta_list_list = [
                list(eta_list) \
                    for _ in range(true_batch_size)
            ]

        arg_dict = {
            "init_latent_seed_list": batch_init_latent_seed_list, 
            # shape = (true_batch_size, num_inference_step)
            "eps_seed_list_list": torch.tensor(batch_eps_seed_list_list), 
            # shape = (true_batch_size, num_inference_step)
            "eta_list_list": torch.tensor(batch_eta_list_list)
        }
        
        # ---------= [Do Sampling] =---------
        batch_img_latent_list = implement_batch(
            batch_idx = batch_idx, 
            true_batch_size = true_batch_size, 

            arg_dict = arg_dict
        )
        
        # ---------= [Save Sample] =---------
        batch_img_pil_list = img_latent_to_pil(
            img_latent_list = batch_img_latent_list, 
            pipeline = pipeline, 

            batch_size = vae_decode_batch_size
        )

        with cf.ThreadPoolExecutor(
            max_workers = concurrent_max_worker
        ) as executor:
            future_list = []

            for sample_idx in tqdm(
                range(true_batch_size), 
                
                desc = f"[Saving Sample]"
            ):
                true_sample_idx = sample_idx_st + sample_idx
                prompt_true_sample_idx = true_sample_idx % num_sample_per_prompt
            
                prompt_idx, prompt = prompt_tuple_list[true_sample_idx]
                
                folder_name = folder_name_list[prompt_idx]
                
                folder_root_path = exp_root_path / folder_name
                cfg_root_path = folder_root_path / "cfg"
                png_root_path = folder_root_path / "png"

                # ---------= [Save Task Cfg] =---------
                # Plain Python values (not tensor rows) so the YAML stays
                # readable and loadable.
                cfg_dict["init_latent"]["init_latent_seed"] = batch_init_latent_seed_list[sample_idx]
                cfg_dict["eps"]["eps_seed_list"] = batch_eps_seed_list_list[sample_idx]
                cfg_dict["eta_list"] = batch_eta_list_list[sample_idx]
                cfg_dict["sample"]["prompt"] = prompt
                cfg_dict["save_sample"]["folder_name"] = folder_name
                
                cfg_dict = convert_numpy_type_to_native_type(cfg_dict)
                save_yaml(
                    cfg_dict, 

                    yaml_root_path = cfg_root_path, 
                    yaml_filename = f"{prompt_true_sample_idx}.yaml"
                )

                # ---------= [Save Sample] =---------
                # Create the folder here, in the submitting thread, so several
                # workers never race on creating the same directory.
                png_root_path.mkdir(parents = True, exist_ok = True)

                future_list.append(
                    executor.submit(
                        save_pil_as_png, 
                        pil = batch_img_pil_list[sample_idx], 
                        png_root_path = png_root_path, 
                        png_filename = f"{prompt_true_sample_idx}.png"
                    )
                )

            # Collect results only after every write has been submitted, so
            # the PNG writes actually run concurrently.
            for future in cf.as_completed(future_list):
                try:
                    future.result()
                except Exception as e:
                    logger(
                        f"`save_pil_as_png()` throws an exception: `{e}`. ", 
                        log_type = "error"
                    )
        
        # ---------= [Clean Up] =---------
        del arg_dict
        del batch_img_latent_list
        del batch_img_pil_list
        gc.collect()
        torch.cuda.empty_cache()

    # ---------= [Compute Wall-clock Time Cost] =---------
    time_ed = time.time()
    time_cost = time_ed - time_st

    time_cost_dict = {
        "time_cost": round(time_cost)
    }
    save_yaml(
        time_cost_dict, 

        yaml_root_path = exp_root_path, 
        yaml_filename = f"time_cost.yaml"
    )

    logger(
        f"Sampling finished, wall-clock time cost: {round(time_cost)} second(s). "
    )

def run_sample_scheduled(
    cfg: DictConfig
):
    """Public entry point for scheduled sampling.

    Forwards ``cfg`` unchanged to ``run_sample_scheduled_implement()``, which
    performs all of the work; nothing is returned.

    Args:
        cfg: Experiment config passed straight through.
    """
    run_sample_scheduled_implement(cfg = cfg)
