from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

from abc import ABC, abstractmethod

from PIL import Image

import torch

import gc

from util.basic_util import get_attr
from util.torch_util import (
    get_list_slicing, 
    tsfm_to_1d_array
)
from util.pipeline_util import img_latent_to_pil

from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.pipelines import (
    StableDiffusionXLPipeline, 
    PixArtAlphaPipeline
)


class RewardModel(ABC):
    # _nfe_cal_dynamics = 0
    # _nfe_cal_final_reward = 0
    

    # @staticmethod
    # def get_nfe_cal_dynamics(
    # ) -> int:
    #     # `get_nfe_cal_dynamics()` done
    #     return RewardModel._nfe_cal_dynamics
    

    # @staticmethod
    # def get_nfe_cal_final_reward(
    # ) -> int:
    #     # `get_nfe_cal_final_reward()` done
    #     return RewardModel._nfe_cal_final_reward
    

    # @staticmethod
    # def get_nfe(
    # ) -> Tuple[int, int]:
    #     # `get_nfe()` done
    #     return (
    #         RewardModel._nfe_cal_dynamics, 
    #         RewardModel._nfe_cal_final_reward
    #     )


    # @staticmethod
    # def reset_nfe(
    # ) -> int:
    #     RewardModel._nfe_cal_dynamics = 0
    #     RewardModel._nfe_cal_final_reward = 0

    #     # `reset_nfe()` done
    #     pass
    
    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: List[torch.Tensor] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: int = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        # `reward_shaping_policy` in [
        #     "disabled", 
        #     "latent_reward", 
        #     "potential_based", 
        #     "skipping"
        # ]
        reward_shaping_policy: str = "disabled", 
        # `cal_intermediate_reward_policy` in [
        #     "immediate_posterior_mean", 
        #     "immediate_score_function", 
        #     "look_ahead", 
        #     "sequential", 
        #     "discount", 
        # ]
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        self._device = device
        self._vae_decode_batch_size = vae_decode_batch_size

        # ---------= [Pipeline] =---------
        self._pipeline = pipeline

        self._num_sample_per_prompt = 1

        if self._pipeline is not None:
            if not isinstance(prompt_emb_list, torch.Tensor):
                prompt_emb_list = torch.stack(
                    prompt_emb_list, 

                    dtype = self._pipeline.dtype, 
                    device = self._pipeline.device
                )
            self._prompt_emb_list = prompt_emb_list

            self._param_dict = param_dict
            self._num_sample_per_prompt = num_sample_per_prompt

            self._num_inference_step = num_inference_step

            self._scheduler = self._pipeline.scheduler
            if not isinstance(self._scheduler, DDIMScheduler):
                raise ValueError(
                    f"Only support `DDIMScheduler`. "
                )
        
            self._more_param_op = None
            if isinstance(self._pipeline, StableDiffusionXLPipeline):
                self._more_param_op = "sdxl"
                self._prepare_sdxl_param()
            elif isinstance(self._pipeline, PixArtAlphaPipeline):
                self._more_param_op = "pixart_alpha"
                self._prepare_pixart_alpha_param()
        
        # ---------= [NFE] =---------
        do_cfg = False
        num_prompt = 1

        if self._pipeline is not None:
            do_cfg = self._pipeline.do_classifier_free_guidance
            num_prompt = self._prompt_emb_list.shape[0] // (1 + do_cfg)
        
        self._do_cfg = do_cfg
        self._num_prompt = num_prompt

        num_sample = self._num_prompt * self._num_sample_per_prompt

        self.nfe_cal_dynamics_list = [0] * num_sample
        self.nfe_cal_intermediate_reward_list = [0] * num_sample
        self.nfe_cal_final_reward_list = [0] * num_sample

        # ---------= [Reward] =---------
        self._reward_shape = reward_shape
        
        self._reward_dtype = reward_dtype
        if isinstance(self._reward_dtype, str):
            self._reward_dtype = get_attr("torch", self._reward_dtype)

        self._offload_to_cpu = offload_to_cpu

        # ---------= [Parallel] =---------
        self._cal_dynamics_batch_size = cal_dynamics_batch_size
        self._cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size
        self._cal_final_reward_batch_size = cal_final_reward_batch_size

        # ---------= [Cal Intermediate Reward] =---------
        if reward_shaping_policy not in [
            "disabled", 

            "latent_reward", 
            
            # "potential_based", 
            # "skipping"
        ]:
            raise ValueError(
                f"Unsupported `reward_shaping_policy`, "
                f"got `{reward_shaping_policy}`. "
            )

        if cal_intermediate_reward_policy not in [
            "immediate_posterior_mean", 
            "immediate_score_function", 
            
            "look_ahead", 

            "sequential", 

            "discount"
        ]:
            raise ValueError(
                f"Unsupported `cal_intermediate_reward_policy`, "
                f"got `{cal_intermediate_reward_policy}`. "
            )

        self.reward_shaping_policy = reward_shaping_policy
        self.cal_intermediate_reward_policy = cal_intermediate_reward_policy

        # if potential_exp_growing and (self.reward_shaping_policy != "potential_based"):
        #     raise ValueError(
        #         f"Exponentially growing potential function can only be enabled "
        #         f"when `reward_shaping_policy = \"potential_based\"`. "
        #     )
        # self.potential_exp_growing = potential_exp_growing
        # self.potential_exp_base = potential_exp_base

        # ---------= [Clip] =---------
        self.clip_range = None

        # ---------= [Scaling] =---------
        self.norm_constant = 1.0
        self.bias_constant = 0.0

        # `__init__()` done
        pass


    def _prepare_sdxl_param(
        self
    ):
        """
        Func: 
            Prepare more parameters for SDXL forward. 
        """

        add_text_embeds = self._param_dict.pop("add_text_embeds")
        add_time_ids = self._param_dict.pop("add_time_ids")

        self._add_text_emb_list = add_text_embeds
        self._add_time_id_list = add_time_ids

        # `_prepare_sdxl_param()` done
        pass


    def _prepare_pixart_alpha_param(
        self
    ):
        """
        Func: 
            Prepare more parameters for PixArt-Alpha forward. 
        """

        prompt_attention_mask = self._param_dict.pop("prompt_attention_mask")

        added_cond_kwargs = self._param_dict.pop("added_cond_kwargs")
        resolution = added_cond_kwargs["resolution"]
        aspect_ratio = added_cond_kwargs["aspect_ratio"]

        self._prompt_attention_mask_list = prompt_attention_mask

        self._resolution_list = resolution
        self._aspect_ratio_list = aspect_ratio

        # `_prepare_sdxl_param()` done
        pass

    def _get_concat_param_list(
        self, 
    
        param_list: Union[List[torch.Tensor], torch.Tensor], 
        sample_idx_list: List[int]
    ) -> torch.Tensor:
        """
        Func:
            Get concatenate parameters according to `sample_idx_list`. 

        Ret:
            `concat_param_list` (`torch.Tensor`): The derived `concat_param_list`. 
        """

        if self._do_cfg:
            negative_param_list = []
            positive_param_list = []
            
            for sample_idx in sample_idx_list:
                prompt_idx = sample_idx // self._num_sample_per_prompt

                negative_param_list.append(
                    param_list[prompt_idx]
                )

                positive_param_list.append(
                    param_list[prompt_idx + self._num_prompt]
                )

                # goto `for sample_idx`
                pass
        
            concat_param_list = torch.stack(
                negative_param_list + positive_param_list
            )
        else:
            concat_param_list = torch.stack(
                [
                    param_list[sample_idx % self._num_sample_per_prompt] \
                        for sample_idx in sample_idx_list
                ]
            )

        # ---------= [Clean Up] =---------
        if self._do_cfg:
            del negative_param_list, positive_param_list

        # `_get_concat_param_list()` done
        return concat_param_list
    

    # (discarded)
    # def _get_prompt_emb_list(
    #     self, 

    #     sample_idx_list: List[int]
    # ) -> torch.Tensor:
    #     """
    #     Func:
    #         Get `prompt_emb_list` according to `sample_idx_list`. 

    #     Ret:
    #         `prompt_emb_list` (`torch.Tensor`): The derived `prompt_emb_list`. 
    #     """

    #     if self._do_cfg:
    #         negative_prompt_emb_list = []
    #         prompt_emb_list = []
            
    #         for sample_idx in sample_idx_list:
    #             prompt_idx = sample_idx // self._num_sample_per_prompt

    #             negative_prompt_emb_list.append(
    #                 self._prompt_emb_list[prompt_idx]
    #             )

    #             prompt_emb_list.append(
    #                 self._prompt_emb_list[prompt_idx + self._num_prompt]
    #             )

    #             # goto `for sample_idx`
    #             pass
        
    #         prompt_emb_list = torch.stack(
    #             negative_prompt_emb_list + prompt_emb_list
    #         )
    #     else:
    #         prompt_emb_list = torch.stack(
    #             [
    #                 self._prompt_emb_list[sample_idx % self._num_sample_per_prompt] \
    #                     for sample_idx in sample_idx_list
    #             ]
    #         )

    #     # `_get_prompt_emb_list()` done
    #     return prompt_emb_list


    @torch.no_grad()
    def batch_cal_dynamics(
        self, 

        sample_idx_list: List[int], 

        # `z_{t_i}`
        latent_list: Union[torch.Tensor, List[torch.Tensor]], 

        # `t_i`
        timestep_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 

        # `i`
        timestep_idx_list: Optional[Union[int, List[int], torch.Tensor]] = None, 

        # `t_{i + 1}`
        prev_timestep_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 

        eta_list: Union[float, torch.Tensor] = None, 
        eps_list: torch.Tensor = None
    ) -> torch.Tensor:
        """
        NB:
            The original note states that the input `latent_list` will be modified. 
            Nothing in this method writes to it; whether `self._pipeline.step()` 
            does is not visible here (TODO: confirm). 

        Func:
            Compute the next state given the current state and action, 
            in batches of `self._cal_dynamics_batch_size` samples. 
            For each batch, the noise is predicted by `self._pipeline.get_noise_pred()` 
            and the latents are advanced by `self._pipeline.step()`. 
            The dynamics NFE counter of every sample in `sample_idx_list` 
            is increased by one. 

        Args:
            `sample_idx_list`: The sample index of each latent, 
                used to pick the prompt-dependent parameters and to count the NFE. 
            `latent_list`: The current latents `z_{t_i}`. 
            `timestep_list`: The current timesteps `t_i` (optional). 
            `timestep_idx_list`: The current timestep indices `i` (optional). 
            `prev_timestep_list`: The timesteps `t_{i + 1}` passed to the step (optional). 
            `eta_list`: The actions. Either a tensor whose first dimension is `1` 
                (repeated for every latent) or `len(latent_list)`, 
                or a non-tensor value that is shared by all latents. 
            `eps_list`: The noise passed to the step (optional). 

        Ret:
            `latent_list` (`torch.Tensor`): The next state.

        Raises:
            `ValueError`: If `eta_list` is missing, 
                or its length does not match the length of `latent_list`. 
        """

        num_latent = len(latent_list)

        # ---------= [Prepare `eta_list`] =---------
        if not isinstance(eta_list, torch.Tensor):
            if eta_list is None:
                raise ValueError(
                    "`eta_list` must be provided. "
                )

            # `latent_list` may be a plain list of tensors, 
            # which has neither `dtype` nor `device`. 
            if isinstance(latent_list, torch.Tensor):
                ref_latent = latent_list
            else:
                ref_latent = latent_list[0]

            eta_list = torch.tensor(
                [
                    [eta_list]
                ], 

                dtype = ref_latent.dtype, 
                device = ref_latent.device
            )

        num_eta = len(eta_list)

        if num_eta == 1:
            # Share the single action among all latents. 
            action_shape_length = eta_list.dim() - 1

            eta_list = eta_list.repeat(
                (num_latent, *([1] * action_shape_length))
            )
        elif num_eta != num_latent:
            raise ValueError(
                f"The length of `eta_list` does not match the length of `latent_list`, "
                f"got `{num_eta}` and `{num_latent}`. "
            )

        # ---------= [Prepare Timesteps] =---------
        if timestep_list is not None:
            timestep_list = tsfm_to_1d_array(
                array = timestep_list, 
                target_length = num_latent, 

                dtype = torch.int32, 
                device = "cpu"
            )

        if timestep_idx_list is not None:
            timestep_idx_list = tsfm_to_1d_array(
                array = timestep_idx_list, 
                target_length = num_latent, 

                dtype = torch.int32, 
                device = "cpu"
            )

        def slice_or_none(array, st, ed):
            # Slice `array[st: ed]`, passing `None` through. 
            if array is None:
                return None

            return get_list_slicing(
                array, 

                st = st, 
                ed = ed
            )

        # ---------= [Batched Step] =---------
        batch_size = self._cal_dynamics_batch_size
        num_batch = (num_latent + batch_size - 1) // batch_size

        batch_latent_list_list = []

        for batch_idx in range(num_batch):
            sample_idx_st = batch_idx * batch_size
            # The last batch may be smaller than `batch_size`. 
            sample_idx_ed_plus_one = min(sample_idx_st + batch_size, num_latent)

            # ---------= [Prepare Everything] =---------
            batch_latent_list = slice_or_none(
                latent_list, sample_idx_st, sample_idx_ed_plus_one
            )
            batch_prev_timestep_list = slice_or_none(
                prev_timestep_list, sample_idx_st, sample_idx_ed_plus_one
            )
            batch_timestep_list = slice_or_none(
                timestep_list, sample_idx_st, sample_idx_ed_plus_one
            )
            batch_timestep_idx_list = slice_or_none(
                timestep_idx_list, sample_idx_st, sample_idx_ed_plus_one
            )
            batch_eta_list = slice_or_none(
                eta_list, sample_idx_st, sample_idx_ed_plus_one
            )
            batch_eps_list = slice_or_none(
                eps_list, sample_idx_st, sample_idx_ed_plus_one
            )

            batch_sample_idx_list = sample_idx_list[sample_idx_st: sample_idx_ed_plus_one]

            prompt_emb_list = self._get_concat_param_list(
                param_list = self._prompt_emb_list, 
                sample_idx_list = batch_sample_idx_list
            )

            # Batch-specific parameters that are only valid for this forward. 
            more_param_dict = {}

            if self._more_param_op == "sdxl":
                more_param_dict["add_text_embeds"] = self._get_concat_param_list(
                    param_list = self._add_text_emb_list, 
                    sample_idx_list = batch_sample_idx_list
                )
                more_param_dict["add_time_ids"] = self._get_concat_param_list(
                    param_list = self._add_time_id_list, 
                    sample_idx_list = batch_sample_idx_list
                )
            elif self._more_param_op == "pixart_alpha":
                more_param_dict["prompt_attention_mask"] = self._get_concat_param_list(
                    param_list = self._prompt_attention_mask_list, 
                    sample_idx_list = batch_sample_idx_list
                )
                more_param_dict["added_cond_kwargs"] = {
                    "resolution": self._get_concat_param_list(
                        param_list = self._resolution_list, 
                        sample_idx_list = batch_sample_idx_list
                    ), 
                    "aspect_ratio": self._get_concat_param_list(
                        param_list = self._aspect_ratio_list, 
                        sample_idx_list = batch_sample_idx_list
                    )
                }

            # ---------= [Step] =---------
            self._param_dict.update(more_param_dict)

            try:
                (
                    # batch_noise_pred_list.shape = (true_batch_size, 4, latent_height, latent_width)
                    batch_noise_pred_list, 

                    # batch_timestep_list.shape = (true_batch_size, )
                    batch_timestep_list
                ) = self._pipeline.get_noise_pred(
                    param_dict = self._param_dict, 
                    prompt_emb_list = prompt_emb_list, 

                    latent_list = batch_latent_list, 

                    timestep_list = batch_timestep_list, 
                    timestep_idx_list = batch_timestep_idx_list
                )

                batch_latent_list = self._pipeline.step(
                    latent_list = batch_latent_list, 

                    noise_pred = batch_noise_pred_list, 

                    timestep = batch_timestep_list, 
                    prev_timestep = batch_prev_timestep_list, 

                    eta_list = batch_eta_list, 

                    eps = batch_eps_list
                )
            finally:
                # Never leave batch-specific entries in the shared `_param_dict`, 
                # not even if the forward fails. 
                for key in more_param_dict:
                    self._param_dict.pop(key, None)

            batch_latent_list_list.append(batch_latent_list)

            # ---------= [Clean Up] =---------
            # Drop the references before emptying the CUDA cache. 
            del batch_prev_timestep_list
            del batch_timestep_list
            del batch_timestep_idx_list
            del batch_eta_list
            del batch_eps_list
            del batch_noise_pred_list
            del batch_sample_idx_list
            del prompt_emb_list
            del more_param_dict
            gc.collect()
            torch.cuda.empty_cache()

            # goto `for batch_idx`
            pass

        # latent_list.shape = (num_latent, 4, latent_height, latent_width)
        latent_list = torch.vstack(batch_latent_list_list)

        # ---------= [Update NFE] =---------
        for sample_idx in sample_idx_list:
            self.nfe_cal_dynamics_list[sample_idx] += 1

            # goto `for sample_idx`
            pass

        # ---------= [Clean Up] =---------
        del batch_latent_list_list
        gc.collect()
        torch.cuda.empty_cache()

        # `batch_cal_dynamics()` done
        return latent_list

    
    @abstractmethod
    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Compute the final reward for the clean images. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The list of the final rewards. 
                final_reward_list.shape = (num_img, 1). 
        """


        # `cal_final_reward_implement()` done
        return 


    @torch.no_grad()
    def cal_final_reward(
        self, 

        sample_idx_list: List[int] = None, 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The list of the final rewards. 
                final_reward_list.shape = (num_img, 1). 
        """

        # ---------= [Prepare Image PIL List] =---------
        if (latent_list is not None) and (img_pil_list is not None):
            logger(
                f"Both `latent_list` and `img_pil_list` are provided, "
                f"`img_pil_list` prioritizes. ", 

                log_type = "warning"
            )

        if img_pil_list is None:
            img_pil_list = img_latent_to_pil(
                img_latent_list = latent_list, 
                pipeline = self._pipeline, 

                batch_size = self._vae_decode_batch_size
            )

        # ---------= [Compute Final Reward] =---------
        num_img = len(img_pil_list)

        batch_size = self._cal_final_reward_batch_size
        
        def implement_batch(
            batch_idx: int, 
            true_batch_size: int
        ) -> List[torch.Tensor]:
            # ---------= [Prepare Everything] =---------
            sample_idx_st = batch_idx * batch_size
            sample_idx_ed_plus_one = sample_idx_st + true_batch_size

            batch_sample_idx_list = None
            if sample_idx_list is not None:
                batch_sample_idx_list = sample_idx_list[sample_idx_st: sample_idx_ed_plus_one]

            batch_latent_list = None
            if latent_list is not None:
                batch_latent_list = get_list_slicing(
                    latent_list, 

                    st = sample_idx_st, 
                    ed = sample_idx_ed_plus_one
                )

            batch_img_pil_list = None
            if img_pil_list is not None:
                batch_img_pil_list = img_pil_list[sample_idx_st: sample_idx_ed_plus_one]

            batch_prompt_list = prompt_list[sample_idx_st: sample_idx_ed_plus_one]
            
            batch_final_reward_list = self.cal_final_reward_implement(
                sample_idx_list = batch_sample_idx_list, 

                latent_list = batch_latent_list, 
                img_pil_list = batch_img_pil_list, 

                prompt_list = batch_prompt_list
            )
            
            # ---------= [Clean Up] =---------
            if batch_sample_idx_list is not None:
                del batch_sample_idx_list
            if batch_latent_list is not None:
                del batch_latent_list
            if batch_img_pil_list is not None:
                del batch_img_pil_list
            if batch_prompt_list is not None:
                del batch_prompt_list
            gc.collect()
            torch.cuda.empty_cache()

            # `implement_batch()` done
            return batch_final_reward_list


        num_batch = (num_img + batch_size - 1) // batch_size

        batch_final_reward_list_list = []

        for batch_idx in range(num_batch):
            if (batch_idx < num_batch - 1) or (num_img % batch_size == 0):
                true_batch_size = batch_size
            else:
                true_batch_size = num_img % batch_size

            # batch_final_reward_list.shape = (true_batch_size, *self.reward_shape)
            batch_final_reward_list = implement_batch(
                batch_idx = batch_idx, 
                true_batch_size = true_batch_size
            )
            
            batch_final_reward_list_list.append(batch_final_reward_list)
            
            # goto `for batch_idx`
            pass
        
        # final_reward_list.shape = (num_img, *self.reward_shape)
        final_reward_list = torch.cat(batch_final_reward_list_list)

        final_reward_list = final_reward_list.reshape(
            (num_img, *self._reward_shape)
        )
        
        # ---------= [Clip] =---------
        if self.clip_range is not None:
            final_reward_list = torch.clip(
                final_reward_list, 
                self.clip_range[0], self.clip_range[1]
            )
        
        # ---------= [Scaling] =---------
        final_reward_list /= self.norm_constant
        final_reward_list += self.bias_constant

        if self._offload_to_cpu:
            final_reward_list = final_reward_list.to("cpu")

        # ---------= [Update NFE] =---------
        if sample_idx_list is not None:
            for sample_idx in sample_idx_list:
                self.nfe_cal_final_reward_list[sample_idx] += 1

                # goto `for sample_idx`
                pass

        # ---------= [Update NFE] =---------
        del batch_final_reward_list_list
        gc.collect()
        torch.cuda.empty_cache()

        # `cal_final_reward()` done
        return final_reward_list


    @torch.no_grad()
    def cal_intermediate_reward(
        self, 

        sample_idx_list: List[int], 

        # # prev_latent_list: `z_{t_{i - 1}}`
        # # prev_latent_list.shape = (batch_size, 4, latent_height, latent_width)
        # prev_latent_list: Optional[torch.Tensor] = None, 

        # # `prev_action_list`: `\eta_{t_{i - 1}}`
        # # prev_action_list.shape = (batch_size, *action_shape)
        # prev_action_list: Optional[torch.Tensor] = None, 

        # latent_list: `z_{t_i}`
        # latent_list.shape = (batch_size, 4, latent_height, latent_width)
        latent_list: Optional[torch.Tensor] = None, 

        prompt_list: Union[str, List[str]] = None,

        # `timestep_idx_list`: `i`
        timestep_idx_list: Union[int, List[int], torch.Tensor] = None, 

        **arg_dict: Optional[Dict]
    ) -> Tuple[
        torch.Tensor, 
        torch.Tensor
    ]:
        """
        Func:
            Compute the intermediate reward for the intermediate latents. 

        Ret:
            `intermediate_reward_list` (`torch.Tensor`): The list of the shaped intermediate rewards. 
                intermediate_reward_list.shape = (num_latent, *self._reward_shape). 
            `pseudo_final_latent_list` (`torch.Tensor`): 
                The list of the predicted final latents used for computing intermediate rewards. 
                    pseudo_final_latent_list.shape = (num_latent, *self._state_shape). 
        """

        assert self.reward_shaping_policy != "disabled"

        # ---------= [Prepare `timestep_list`] =--------
        if not isinstance(timestep_idx_list, torch.Tensor):
            timestep_idx_list = torch.tensor(timestep_idx_list)
        
        num_latent = len(latent_list)
        num_inference_step = len(self._scheduler.timesteps)

        # `timestep_list`: `t_i`
        timestep_list = []

        for timestep_idx in timestep_idx_list:
            if timestep_idx < num_inference_step:
                timestep_list.append(
                    self._scheduler.timesteps[timestep_idx]
                )
            else:
                timestep_list.append(
                    torch.zeros_like(self._scheduler.timesteps[0])
                )

            # goto `for timestep_idx`
            pass

        timestep_list = torch.stack(timestep_list)

        batch_size = self._cal_intermediate_reward_batch_size
        
        def implement_batch(
            batch_idx: int, 
            true_batch_size: int
        ) -> Tuple[
            List[torch.Tensor], 
            List[torch.Tensor]
        ]:
            # ---------= [Prepare Everything] =---------
            sample_idx_st = batch_idx * batch_size
            sample_idx_ed_plus_one = sample_idx_st + true_batch_size

            batch_sample_idx_list = None
            if sample_idx_list is not None:
                batch_sample_idx_list = sample_idx_list[sample_idx_st: sample_idx_ed_plus_one]

            batch_latent_list = get_list_slicing(
                latent_list, 

                st = sample_idx_st, 
                ed = sample_idx_ed_plus_one
            )

            batch_prompt_list = prompt_list[sample_idx_st: sample_idx_ed_plus_one]

            batch_timestep_idx_list = timestep_idx_list[sample_idx_st: sample_idx_ed_plus_one]
            batch_timestep_list = timestep_list[sample_idx_st: sample_idx_ed_plus_one]

            # ---------= [Compute Intermediate Reward] =---------
            if self.cal_intermediate_reward_policy in [
                "immediate_posterior_mean", 
                "immediate_score_function"
            ]:
                if self.reward_shaping_policy == "latent_reward":
                    (
                        batch_intermediate_reward_list, 
                        batch_pseudo_final_latent_list
                    ) = self._cal_intermediate_reward_immediate(
                        sample_idx_list = batch_sample_idx_list, 

                        # `latent_list`: `z_{t_i}`
                        latent_list = batch_latent_list, 

                        prompt_list = batch_prompt_list, 

                        # `timestep_idx_list`: `i`
                        timestep_idx_list = batch_timestep_idx_list
                    )

            elif self.cal_intermediate_reward_policy == "look_ahead":
                num_look_ahead_step = arg_dict.get("num_look_ahead_step", None)

                if num_look_ahead_step is None:
                    raise ValueError(
                        f"`num_look_ahead_step` must be provided "
                        f"when `cal_intermediate_reward_policy == \"look_ahead\"`. "
                    )
                
                if self.reward_shaping_policy == "latent_reward":
                    (
                        batch_intermediate_reward_list, 
                        batch_pseudo_final_latent_list
                    ) = self._cal_intermediate_reward_look_ahead(
                        sample_idx_list = batch_sample_idx_list, 

                        # `latent_list`: `z_{t_i}`
                        latent_list = batch_latent_list, 

                        prompt_list = batch_prompt_list, 

                        # `timestep_list`: `t_i`
                        timestep_list = batch_timestep_list, 

                        num_look_ahead_step = num_look_ahead_step
                    )

            elif self.cal_intermediate_reward_policy in [
                "sequential", 
                "discount"
            ]:
                if self.cal_intermediate_reward_policy == "discount":
                    gamma = arg_dict.get("gamma", None)

                    if gamma is None:
                        raise ValueError(
                            f"`gamma` must be provided "
                            f"when `cal_intermediate_reward_policy == \"discount\"`. "
                        )
                    
                if self.reward_shaping_policy == "latent_reward":
                    (
                        batch_intermediate_reward_list, 
                        batch_pseudo_final_latent_list
                    ) = self._cal_intermediate_reward_sequential(
                        sample_idx_list = batch_sample_idx_list, 

                        # `latent_list`: `z_{t_i}`
                        latent_list = batch_latent_list, 

                        prompt_list = batch_prompt_list, 

                        # `timestep_idx_list`: `i`
                        timestep_idx_list = batch_timestep_idx_list
                    )

                if self.cal_intermediate_reward_policy == "discount":
                    batch_intermediate_reward_list = self._cal_discounted_intermediate_reward_list(
                        intermediate_reward_list = batch_intermediate_reward_list, 

                        # `timestep_idx_list`: `i`
                        timestep_idx_list = batch_timestep_idx_list, 

                        gamma = gamma
                    )

            else:
                raise NotImplementedError(
                    f"Unsupported `cal_intermediate_reward_policy`, got `{self.cal_intermediate_reward_policy}`. "
                )

            # ---------= [Clean Up] =---------
            if batch_sample_idx_list is not None:
                del batch_sample_idx_list
            del batch_latent_list
            del batch_prompt_list
            del batch_timestep_idx_list
            gc.collect()
            torch.cuda.empty_cache()

            # `implement_batch()` done
            return (
                batch_intermediate_reward_list, 
                batch_pseudo_final_latent_list
            )

        num_batch = (num_latent + batch_size - 1) // batch_size

        # intermediate_reward_list = [torch.tensor(0.0)] * num_latent
        intermediate_reward_list = []
        pseudo_final_latent_list = []
        
        for batch_idx in range(num_batch):
            if (batch_idx < num_batch - 1) or (num_latent % batch_size == 0):
                true_batch_size = batch_size
            else:
                true_batch_size = num_latent % batch_size

            # batch_intermediate_reward_list.shape = (true_batch_size, *self.reward_shape)
            (
                batch_intermediate_reward_list, 
                batch_pseudo_final_latent_list
            ) = implement_batch(
                batch_idx = batch_idx, 
                true_batch_size = true_batch_size
            )
            
            intermediate_reward_list.append(batch_intermediate_reward_list)
            pseudo_final_latent_list.append(batch_pseudo_final_latent_list)
            
            # goto `for batch_idx`
            pass

        intermediate_reward_list = torch.vstack(intermediate_reward_list)

        # if not isinstance(intermediate_reward_list, torch.Tensor):
        #     intermediate_reward_list = torch.tensor(
        #         intermediate_reward_list, 

        #         dtype = self._reward_dtype
        #     )

        intermediate_reward_list = intermediate_reward_list.reshape(
            (num_latent, *self._reward_shape)
        )

        if self._offload_to_cpu:
            intermediate_reward_list = intermediate_reward_list.to("cpu")

        pseudo_final_latent_list = torch.vstack(pseudo_final_latent_list)
        
        pseudo_final_latent_list \
            = pseudo_final_latent_list.reshape(num_latent, *pseudo_final_latent_list.shape[1: ])

        # ---------= [Update NFE] =---------
        for sample_idx in sample_idx_list:
            self.nfe_cal_intermediate_reward_list[sample_idx] += 1

            if self.reward_shaping_policy == "latent_reward":
                self.nfe_cal_final_reward_list[sample_idx] += 1

            # goto `for sample_idx`
            pass
        
        # ---------= [Clean Up] =---------
        del timestep_list
        gc.collect()
        torch.cuda.empty_cache()
        
        # `cal_intermediate_reward()` done
        return (
            intermediate_reward_list, 
            pseudo_final_latent_list
        )

    
    @torch.no_grad()
    def _cal_intermediate_reward_immediate(
        self, 

        # latent_list: `z_{t_i}`
        # latent_list.shape = (batch_size, 4, latent_height, latent_width)
        latent_list: torch.Tensor = None, 

        prompt_list: Union[str, List[str]] = None,

        # `i`
        timestep_idx_list: torch.Tensor = None, 

        **arg_dict: Optional[Dict]
    ) -> Tuple[
        torch.Tensor, 
        torch.Tensor
    ]:
        """
        Func:
            Compute the intermediate reward for the intermediate latents, 
                using the `immediate_posterior_mean` or `immediate_score_function` policy. 
            One noise prediction at `z_{t_i}` is converted into an estimate of the final latent, 
                and that estimate is scored with `cal_final_reward()`. 

        Arg:
            `latent_list` (`torch.Tensor`): The intermediate latents `z_{t_i}`. 
            `prompt_list` (`str` or `List[str]`): Forwarded unchanged to `cal_final_reward()`. 
            `timestep_idx_list` (`torch.Tensor`): The timestep indices `i`, 
                forwarded to `self._pipeline.get_noise_pred()`. 
            `arg_dict`: Must contain `sample_idx_list`, which selects the per-sample 
                parameters (prompt embeddings and the model-specific extras). 

        Ret:
            `intermediate_reward_list` (`torch.Tensor`): The list of the intermediate rewards. 
                intermediate_reward_list.shape = (num_latent, *self._reward_shape). 
            `pseudo_final_latent_list` (`torch.Tensor`): 
                The list of the predicted final latents used for computing intermediate rewards. 
                    pseudo_final_latent_list.shape = (num_latent, *self._state_shape). 

        Raise:
            `ValueError`: If `sample_idx_list` is not provided in `arg_dict`. 
            `NotImplementedError`: If `self.cal_intermediate_reward_policy` is not 
                one of the two immediate policies. 
        """

        sample_idx_list = arg_dict.get("sample_idx_list", None)

        if sample_idx_list is None:
            raise ValueError(
                f"`sample_idx_list` must be provided. "
            )

        # Fail before any model call; otherwise an unsupported policy would only 
        #   surface later as an unbound `pseudo_final_latent_list`. 
        if self.cal_intermediate_reward_policy not in [
            "immediate_posterior_mean", 
            "immediate_score_function"
        ]:
            raise NotImplementedError(
                f"Unsupported `cal_intermediate_reward_policy`, got `{self.cal_intermediate_reward_policy}`. "
            )

        prompt_emb_list = self._get_concat_param_list(
            param_list = self._prompt_emb_list, 
            sample_idx_list = sample_idx_list
        )

        # Keys that this call temporarily writes into the shared `self._param_dict`. 
        if self._more_param_op == "sdxl":
            extra_param_key_list = ["add_text_embeds", "add_time_ids"]
        elif self._more_param_op == "pixart_alpha":
            extra_param_key_list = ["prompt_attention_mask", "added_cond_kwargs"]
        else:
            extra_param_key_list = []

        try:
            # ---------= [Model-specific Param] =---------
            if self._more_param_op == "sdxl":
                self._param_dict["add_text_embeds"] = self._get_concat_param_list(
                    param_list = self._add_text_emb_list, 
                    sample_idx_list = sample_idx_list 
                )
                self._param_dict["add_time_ids"] = self._get_concat_param_list(
                    param_list = self._add_time_id_list, 
                    sample_idx_list = sample_idx_list 
                )
            elif self._more_param_op == "pixart_alpha":
                self._param_dict["prompt_attention_mask"] = self._get_concat_param_list(
                    param_list = self._prompt_attention_mask_list, 
                    sample_idx_list = sample_idx_list 
                )
                self._param_dict["added_cond_kwargs"] = {
                    "resolution": self._get_concat_param_list(
                        param_list = self._resolution_list, 
                        sample_idx_list = sample_idx_list 
                    ), 
                    "aspect_ratio": self._get_concat_param_list(
                        param_list = self._aspect_ratio_list, 
                        sample_idx_list = sample_idx_list 
                    )
                }

            # ---------= [Compute Immediate Reward] =---------
            (
                # noise_pred.shape = (batch_size, 4, latent_height, latent_width)
                noise_pred, 

                # timestep.shape = (batch_size, )
                timestep
            ) = self._pipeline.get_noise_pred(
                param_dict = self._param_dict, 
                prompt_emb_list = prompt_emb_list, 

                latent_list = latent_list, 

                timestep_idx_list = timestep_idx_list
            )

            # `alphas_cumprod` is indexed with a CPU copy of the timesteps. 
            # alpha_prod_t.shape: (batch_size, ) -> (batch_size, 1, 1, 1)
            alpha_prod_t = self._scheduler.alphas_cumprod[timestep.cpu()]
            alpha_prod_t = alpha_prod_t.unsqueeze(-1) \
                .unsqueeze(-1) \
                .unsqueeze(-1)
            alpha_prod_t = alpha_prod_t.to(
                dtype = latent_list.dtype, 
                device = latent_list.device
            )

            # All of the following have shape (batch_size, 1, 1, 1). 
            sqrt_alpha_prod_t = alpha_prod_t ** (0.5)
            beta_prod_t = 1 - alpha_prod_t
            sqrt_beta_prod_t = beta_prod_t ** (0.5)

            # pseudo_final_latent_list.shape = (batch_size, 4, latent_height, latent_width)
            if self.cal_intermediate_reward_policy == "immediate_posterior_mean":
                pseudo_final_latent_list = (latent_list - sqrt_beta_prod_t * noise_pred) / sqrt_alpha_prod_t
            else:
                # "immediate_score_function": same estimate written through the score, 
                #   algebraically equal to the branch above up to floating-point rounding. 
                score_function = -noise_pred / sqrt_beta_prod_t

                pseudo_final_latent_list = (latent_list + beta_prod_t * score_function) / sqrt_alpha_prod_t

            # ---------= [Compute Final Reward] =---------
            intermediate_reward_list = self.cal_final_reward(
                sample_idx_list = sample_idx_list, 

                latent_list = pseudo_final_latent_list, 

                prompt_list = prompt_list
            )
        finally:
            # Always restore the shared `self._param_dict`, even if the model call failed. 
            for extra_param_key in extra_param_key_list:
                self._param_dict.pop(extra_param_key, None)

        # ---------= [Clean Up] =---------
        del noise_pred, timestep
        del alpha_prod_t, beta_prod_t
        gc.collect()
        torch.cuda.empty_cache()

        # `_cal_intermediate_reward_immediate()` done
        return (
            intermediate_reward_list, 
            pseudo_final_latent_list
        )


    @torch.no_grad()
    def _get_look_ahead_timestep_list(
        self, 

        timestep_list: torch.Tensor, 

        num_look_ahead_step: Optional[int] = 2
    ) -> torch.Tensor:
        """
        Func:
            Build one look-ahead timestep schedule per entry of `timestep_list`. 
            Each schedule holds `num_look_ahead_step + 1` evenly spaced points that start at 
                the entry's timestep and end at `0`; non-integer points are rounded down. 

        Arg:
            `timestep_list` (`torch.Tensor`): The current timesteps, one per latent. 
            `num_look_ahead_step` (`int`): The number of intervals in each schedule. 

        Ret:
            `look_ahead_timestep_list_list` (`torch.Tensor`): The stacked schedules, 
                with dtype `torch.int` and placed on `self._device`. 
                look_ahead_timestep_list_list.shape = (num_latent, num_look_ahead_step + 1). 
        """

        num_point = num_look_ahead_step + 1

        def build_schedule(start_timestep) -> torch.Tensor:
            # schedule.shape = (num_look_ahead_step + 1, )
            schedule = torch.linspace(start_timestep, 0, num_point)

            return torch.floor(schedule).to(
                dtype = torch.int, 
                device = self._device
            )

        # `_get_look_ahead_timestep_list()` done
        return torch.stack(
            [build_schedule(start_timestep) for start_timestep in timestep_list]
        )


    @torch.no_grad()
    def _cal_intermediate_reward_look_ahead(
        self, 

        latent_list: Optional[torch.Tensor] = None, 

        prompt_list: Union[str, List[str]] = None,

        timestep_list: Optional[torch.Tensor] = None, 

        num_look_ahead_step: int = None, 

        **arg_dict: Optional[Dict]
    ) -> Tuple[
        torch.Tensor, 
        torch.Tensor
    ]:
        """
        Func:
            Compute the intermediate reward for the intermediate latents, 
                using the `look_ahead` policy. 
            The latents are advanced along a coarse timestep schedule obtained from 
                `_get_look_ahead_timestep_list()`, and the result is scored with `cal_final_reward()`. 

        Arg:
            `latent_list` (`torch.Tensor`): The intermediate latents; a clone is advanced, 
                the input itself is not reassigned. 
            `prompt_list` (`str` or `List[str]`): Forwarded unchanged to `cal_final_reward()`. 
            `timestep_list` (`torch.Tensor`): The current timesteps `t_i`, one per latent. 
            `num_look_ahead_step` (`int`): The number of intervals of the look-ahead schedule. 
            `arg_dict`: Must contain `sample_idx_list`; the whole dict is forwarded 
                to `cal_final_reward()`. 

        Ret:
            `intermediate_reward_list` (`torch.Tensor`): The list of the intermediate rewards. 
                intermediate_reward_list.shape = (num_latent, *self._reward_shape). 
            `pseudo_final_latent_list` (`torch.Tensor`): 
                The list of the predicted final latents used for computing intermediate rewards. 
                    pseudo_final_latent_list.shape = (num_latent, *self._state_shape). 

        Raise:
            `ValueError`: If `sample_idx_list` or `num_look_ahead_step` is not provided. 
        """

        sample_idx_list = arg_dict.get("sample_idx_list", None)

        if sample_idx_list is None:
            raise ValueError(
                f"`sample_idx_list` must be provided. "
            )

        if num_look_ahead_step is None:
            raise ValueError(
                f"`num_look_ahead_step` must be provided. "
            )

        # ---------= [Get Look-ahead Timestep List] =---------
        # look_ahead_timestep_list_list.shape = (num_latent, num_look_ahead_step + 1)
        look_ahead_timestep_list_list = self._get_look_ahead_timestep_list(
            timestep_list = timestep_list, 

            num_look_ahead_step = num_look_ahead_step
        )

        # ---------= [Look-ahead] =---------
        prompt_emb_list = self._get_concat_param_list(
            param_list = self._prompt_emb_list, 
            sample_idx_list = sample_idx_list
        )

        # Keys that this call temporarily writes into the shared `self._param_dict`. 
        if self._more_param_op == "sdxl":
            extra_param_key_list = ["add_text_embeds", "add_time_ids"]
        elif self._more_param_op == "pixart_alpha":
            extra_param_key_list = ["prompt_attention_mask", "added_cond_kwargs"]
        else:
            extra_param_key_list = []

        try:
            if self._more_param_op == "sdxl":
                self._param_dict["add_text_embeds"] = self._get_concat_param_list(
                    param_list = self._add_text_emb_list, 
                    sample_idx_list = sample_idx_list 
                )
                self._param_dict["add_time_ids"] = self._get_concat_param_list(
                    param_list = self._add_time_id_list, 
                    sample_idx_list = sample_idx_list 
                )
            elif self._more_param_op == "pixart_alpha":
                self._param_dict["prompt_attention_mask"] = self._get_concat_param_list(
                    param_list = self._prompt_attention_mask_list, 
                    sample_idx_list = sample_idx_list 
                )
                self._param_dict["added_cond_kwargs"] = {
                    "resolution": self._get_concat_param_list(
                        param_list = self._resolution_list, 
                        sample_idx_list = sample_idx_list 
                    ), 
                    "aspect_ratio": self._get_concat_param_list(
                        param_list = self._aspect_ratio_list, 
                        sample_idx_list = sample_idx_list 
                    )
                }

            tmp_latent_list = latent_list.clone()

            # NOTE(review): the schedule has `num_look_ahead_step` intervals but only 
            #   `num_look_ahead_step - 1` of them are stepped here, so the last column 
            #   (timestep `0`) is never reached -- confirm whether this is intended. 
            for timestep_idx in range(num_look_ahead_step - 1):
                batch_cur_timestep = look_ahead_timestep_list_list[:, timestep_idx]
                batch_prev_timestep = look_ahead_timestep_list_list[:, timestep_idx + 1]

                (
                    # batch_noise_pred.shape = (true_batch_size, 4, latent_height, latent_width)
                    batch_noise_pred, 

                    # The returned timestep is not needed: `batch_cur_timestep` is used instead. 
                    _
                ) = self._pipeline.get_noise_pred(
                    param_dict = self._param_dict, 
                    prompt_emb_list = prompt_emb_list, 

                    latent_list = tmp_latent_list, 

                    timestep_list = batch_cur_timestep
                )

                # tmp_latent_list.shape = (true_batch_size, 4, latent_height, latent_width)
                tmp_latent_list = self._pipeline.step(
                    latent_list = tmp_latent_list, 

                    noise_pred = batch_noise_pred, 

                    timestep = batch_cur_timestep, 
                    prev_timestep = batch_prev_timestep, 

                    eta_list = 0.0, 
                    eps = None
                )

                # ---------= [Clean Up] =---------
                del batch_cur_timestep, batch_prev_timestep
                del batch_noise_pred
                gc.collect()
                torch.cuda.empty_cache()

                # goto `for timestep_idx`
                pass

            pseudo_final_latent_list = tmp_latent_list

            intermediate_reward_list = self.cal_final_reward(
                # sample_idx_list = sample_idx_list,  # included in `arg_dict`

                latent_list = pseudo_final_latent_list, 

                prompt_list = prompt_list, 

                **arg_dict
            )
        finally:
            # Always restore the shared `self._param_dict`, even if a model call failed. 
            for extra_param_key in extra_param_key_list:
                self._param_dict.pop(extra_param_key, None)

        # ---------= [Clean Up] =---------
        del look_ahead_timestep_list_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_cal_intermediate_reward_look_ahead()` done
        return (
            intermediate_reward_list, 
            pseudo_final_latent_list
        )


    @torch.no_grad()
    def _cal_intermediate_reward_sequential(
        self, 

        latent_list: torch.Tensor = None, 

        prompt_list: Union[str, List[str]] = None,

        timestep_idx_list: torch.Tensor = None, 

        **arg_dict: Optional[Dict]
    ) -> Tuple[
        torch.Tensor, 
        torch.Tensor
    ]:
        """
        Func:
            Compute the intermediate reward for the intermediate latents, 
                using the `sequential` policy. 
            Every latent is rolled forward with `batch_cal_dynamics()`, one timestep index at a time, 
                until its index reaches `self._num_inference_step`; the rolled-out latents are then 
                scored with `cal_final_reward()`. 

        Arg:
            `latent_list` (`torch.Tensor`): The intermediate latents; a clone is rolled out, 
                the input is left untouched. 
            `prompt_list` (`str` or `List[str]`): Forwarded unchanged to `cal_final_reward()`. 
            `timestep_idx_list` (`torch.Tensor`): The starting timestep index `i` of each latent. 
            `arg_dict`: Must contain `sample_idx_list`; the dict is forwarded to 
                `batch_cal_dynamics()` and `cal_final_reward()`. 

        Ret:
            `intermediate_reward_list` (`torch.Tensor`): The list of the intermediate rewards. 
                intermediate_reward_list.shape = (num_latent, *self._reward_shape). 
            `pseudo_final_latent_list` (`torch.Tensor`): 
                The list of the predicted final latents used for computing intermediate rewards. 
                    pseudo_final_latent_list.shape = (num_latent, *self._state_shape). 

        Raise:
            `ValueError`: If `sample_idx_list` is not provided in `arg_dict`. 
        """

        sample_idx_list = arg_dict.get("sample_idx_list", None)

        if sample_idx_list is None:
            raise ValueError(
                f"`sample_idx_list` must be provided. "
            )

        num_latent = len(latent_list)

        # ---------= [Sequential] =---------
        tmp_latent_list = latent_list.clone()
        cur_timestep_idx_list = timestep_idx_list.clone()

        while True:
            # Positions of the latents that have not reached the last timestep index yet. 
            need_cal_idx_list = [
                i \
                    for i, cur_timestep_idx in enumerate(cur_timestep_idx_list) \
                        if cur_timestep_idx < self._num_inference_step
            ]

            if len(need_cal_idx_list) <= 0:
                break

            # Keep `sample_idx_list` aligned with the active latents once some of them have finished. 
            # Only done when it has one entry per latent; otherwise it is forwarded as is. 
            batch_arg_dict = arg_dict

            if (len(need_cal_idx_list) < num_latent) and (len(sample_idx_list) == num_latent):
                if isinstance(sample_idx_list, torch.Tensor):
                    batch_sample_idx_list = sample_idx_list[torch.tensor(need_cal_idx_list)]
                else:
                    batch_sample_idx_list = [sample_idx_list[i] for i in need_cal_idx_list]

                batch_arg_dict = {
                    **arg_dict, 
                    "sample_idx_list": batch_sample_idx_list
                }

            need_cal_idx_list = torch.tensor(need_cal_idx_list)

            # Read from the evolving copy so that each step continues from the previous step's output. 
            batch_latent_list = tmp_latent_list[need_cal_idx_list]
            batch_timestep_idx_list = cur_timestep_idx_list[need_cal_idx_list]

            # batch_latent_list.shape = (true_batch_size, 4, latent_height, latent_width)
            batch_latent_list = self.batch_cal_dynamics(
                # sample_idx_list = sample_idx_list,  # included in `batch_arg_dict`

                latent_list = batch_latent_list, 

                timestep_idx_list = batch_timestep_idx_list, 

                eta_list = 0.0, 
                eps_list = None, 

                **batch_arg_dict
            )

            tmp_latent_list[need_cal_idx_list] = batch_latent_list
            cur_timestep_idx_list[need_cal_idx_list] += 1

            # ---------= [Clean Up] =---------
            del need_cal_idx_list
            del batch_latent_list
            del batch_timestep_idx_list
            del batch_arg_dict
            gc.collect()
            torch.cuda.empty_cache()

            # goto `while True`
            pass

        pseudo_final_latent_list = tmp_latent_list

        intermediate_reward_list = self.cal_final_reward(
            # sample_idx_list = sample_idx_list,  # included in `arg_dict`
            
            latent_list = pseudo_final_latent_list, 

            prompt_list = prompt_list, 

            **arg_dict
        )

        # ---------= [Clean Up] =---------
        del tmp_latent_list, cur_timestep_idx_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_cal_intermediate_reward_sequential()` done
        return (
            intermediate_reward_list, 
            pseudo_final_latent_list
        )


    @torch.no_grad()
    def _cal_discounted_intermediate_reward_list(
        self, 

        # intermediate_reward_list.shape = (batch_size, *reward_shape)
        intermediate_reward_list: torch.Tensor = None, 

        # timestep_idx_list.shape = (batch_size, 1)
        timestep_idx_list: torch.Tensor = None, 

        gamma: float = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the discounted intermediate reward list. 
            The reward at timestep index `i` is multiplied by 
                `gamma ** (self._num_inference_step - i)`. 

        Arg:
            `intermediate_reward_list` (`torch.Tensor`): The rewards to discount. 
            `timestep_idx_list` (`torch.Tensor` or a sequence of ints): 
                The timestep index `i` of each reward. 
            `gamma` (`float`): The discount base. 
            `arg_dict`: Accepted for interface compatibility; not used. 

        Ret:
            `discounted_intermediate_reward_list` (`torch.Tensor`): The list of the discounted intermediate rewards. 
                discounted_intermediate_reward_list.shape = (num_latent, *self._reward_shape). 
        """

        # One vectorized power instead of one `torch.pow()` call per sample. 
        # `as_tensor` also accepts a plain sequence of ints. 
        discount_factor_list = torch.pow(
            gamma, 
            self._num_inference_step - torch.as_tensor(timestep_idx_list)
        ).to(
            dtype = intermediate_reward_list.dtype, 
            device = intermediate_reward_list.device
        )

        batch_size = discount_factor_list.shape[0]
        reward_list_shape = (batch_size, *self._reward_shape)

        if discount_factor_list.numel() == torch.Size(reward_list_shape).numel():
            # One factor per reward element: lay the factors out like the rewards. 
            discount_factor_list = discount_factor_list.reshape(reward_list_shape)
        else:
            # One factor per sample but several reward elements per sample: 
            #   (batch_size, ) -> (batch_size, 1, ..., 1), so the factor broadcasts over the reward dims. 
            discount_factor_list = discount_factor_list.reshape(
                (batch_size, *([1] * len(self._reward_shape)))
            )

        discounted_intermediate_reward_list = intermediate_reward_list * discount_factor_list

        # `_cal_discounted_intermediate_reward_list()` done
        return discounted_intermediate_reward_list
    
