from util.logger import logger

from typing import Optional, Dict, Tuple, List

import torch


def get_reward_model(
    # Supported values: "color_channel_reward", "laplacian_var_reward",
    # "clip_score", "compressibility_reward", "incompressibility_reward",
    # "hps_v2", "image_reward", "pick_score", "compressibility_hps_v2",
    # "incompressibility_hps_v2"
    reward_model_type: str,

    # ---------= [Pipeline] =---------
    pipeline: "StableDiffusionPipeline" = None,
    num_inference_step: int = None,

    # ---------= [Param] =---------
    prompt_emb_list: List[torch.Tensor] = None,
    param_dict: Dict = None,
    num_sample_per_prompt: int = None,

    # ---------= [Reward] =---------
    reward_shape: Tuple = (1, ),
    reward_dtype: str = "float32",
    offload_to_cpu: Optional[bool] = True,

    # ---------= [Parallel] =---------
    cal_dynamics_batch_size: int = 1,
    cal_intermediate_reward_batch_size: int = 1,
    cal_final_reward_batch_size: int = 1,

    # ---------= [Reward Shaping] =---------
    # `reward_shaping_policy` in [
    #     "disabled",
    #     "latent_reward",
    #     "potential_based",
    #     "skipping"
    # ]
    reward_shaping_policy: str = "disabled",
    # `cal_intermediate_reward_policy` in [
    #     "immediate_posterior_mean",
    #     "immediate_score_function",
    #     "look_ahead",
    #     "sequential",
    #     "discount",
    # ]
    # (currently disabled) potential_exp_growing: bool = False,
    # (currently disabled) potential_exp_base: float = 1.0,
    cal_intermediate_reward_policy: str = "immediate_posterior_mean",

    device: Optional[str] = "cpu",

    vae_decode_batch_size: Optional[int] = 10,

    **arg_dict: Dict
) -> "RewardModel":
    """
    Func:
        Factory that picks the reward model class matching
        `reward_model_type`, instantiates it and returns the instance.

    Args:
        `reward_model_type` (`str`): Selects the implementation; see the
            list of supported values next to the parameter.
        All remaining named parameters are passed to the selected class's
            constructor unchanged, under the same keyword names.
        `**arg_dict`: Extra keyword arguments passed to the constructor
            as well. For the "compressibility_*" / "incompressibility_*"
            types the key `"inv"` is (over)written here: `True` for the
            "compressibility_*" names, `False` for the
            "incompressibility_*" names.

    Ret:
        `reward_model` (`RewardModel`): The constructed reward model.

    Raises:
        `NotImplementedError`: `reward_model_type` is not a supported value.
    """

    # Each family below shares one class; its two names differ only in `inv`.
    compress_family = ("compressibility_reward", "incompressibility_reward")
    compress_hps_family = ("compressibility_hps_v2", "incompressibility_hps_v2")
    single_name_types = (
        "color_channel_reward",
        "laplacian_var_reward",
        "clip_score",
        "hps_v2",
        "image_reward",
        "pick_score",
    )

    # Reject unknown types before attempting any import.
    if not (
        reward_model_type in single_name_types
        or reward_model_type in compress_family
        or reward_model_type in compress_hps_family
    ):
        raise NotImplementedError(
            f"Unsupported `reward_model_type`, got `{reward_model_type}`. "
        )

    # The class is imported inside the matching branch, so only the module
    # of the requested reward model is loaded.
    if reward_model_type in compress_family:
        from .compressibility_reward import Compressibility_RewardModel as reward_cls
        arg_dict["inv"] = reward_model_type == "compressibility_reward"
    elif reward_model_type in compress_hps_family:
        from .compressibility_hps_v2 import Compressibility_HPS_v2_RewardModel as reward_cls
        arg_dict["inv"] = reward_model_type == "compressibility_hps_v2"
    elif reward_model_type == "color_channel_reward":
        from .color_channel_reward import ColorChannel_RewardModel as reward_cls
    elif reward_model_type == "laplacian_var_reward":
        from .laplacian_var_reward import LaplacianVariance_RewardModel as reward_cls
    elif reward_model_type == "clip_score":
        from .clip_score import CLIPScore_RewardModel as reward_cls
    elif reward_model_type == "hps_v2":
        from .hps_v2 import HumanPreferenceScore_v2_RewardModel as reward_cls
    elif reward_model_type == "image_reward":
        from .image_reward import ImageReward_RewardModel as reward_cls
    else:
        # Only "pick_score" remains after the membership check above.
        from .pick_score import PickScore_RewardModel as reward_cls

    # Gather the constructor keywords in the order they are forwarded.
    init_kwargs = {
        # ---------= [Pipeline] =---------
        "pipeline": pipeline,
        "num_inference_step": num_inference_step,
        # ---------= [Param] =---------
        "prompt_emb_list": prompt_emb_list,
        "param_dict": param_dict,
        "num_sample_per_prompt": num_sample_per_prompt,
        # ---------= [Reward] =---------
        "reward_shape": reward_shape,
        "reward_dtype": reward_dtype,
        "offload_to_cpu": offload_to_cpu,
        # ---------= [Parallel] =---------
        "cal_dynamics_batch_size": cal_dynamics_batch_size,
        "cal_intermediate_reward_batch_size": cal_intermediate_reward_batch_size,
        "cal_final_reward_batch_size": cal_final_reward_batch_size,
        # ---------= [Reward Shaping] =---------
        "reward_shaping_policy": reward_shaping_policy,
        "cal_intermediate_reward_policy": cal_intermediate_reward_policy,
        "device": device,
        "vae_decode_batch_size": vae_decode_batch_size,
    }
    # `arg_dict` cannot contain any key above: those names are bound to the
    # explicit parameters of this function, so this only appends extras.
    init_kwargs.update(arg_dict)

    # `get_reward_model()` done
    return reward_cls(**init_kwargs)
