#!/usr/bin/env bash
# Launch a `python main.py` experiment run with Hydra-style overrides.
#
# Usage: <script> [gpu_id] [nfe] [exploration_coef] [seed]
#   gpu_id            exported as CUDA_VISIBLE_DEVICES            (default 0)
#   nfe               NFE budget for the dynamics calls           (default 30)
#   exploration_coef  UCB exploration coefficient                 (default 3.0)
#   seed              global seed forwarded to main.py            (default 42)
# An unset or empty positional argument falls back to its default.
set -euo pipefail

# defaults
DEFAULT_GPU_ID=0
DEFAULT_NFE=30
DEFAULT_EXPLORATION_COEF=3.0
GLOBAL_SEED=42

# Positional overrides (${N:-default} also covers an explicitly empty "").
gpu_id="${1:-${DEFAULT_GPU_ID}}"
default_nfe="${2:-${DEFAULT_NFE}}"
exploration_coef="${3:-${DEFAULT_EXPLORATION_COEF}}"
global_seed="${4:-${GLOBAL_SEED}}"

# Restrict the launched process to the selected GPU(s).
export CUDA_VISIBLE_DEVICES="${gpu_id}"


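# (test only) Test the Beta "soft" policy.
# NOTE: the commented-out runs below set task.mcts.beta.update_policy="value_gradient".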
# python main.py \
#     exp_name.default_exp_name="exp" \
#     vae_decode_batch_size=30 \
#     seed=${global_seed} \
#     pipeline=sd_v1_4 \
#     task=search/run_optimal_control_mcts/sd_v1_4/template \
#     task.init_latent.seed_list=0 \
#     task.eps.seed_list=3072 \
#     task.prompt_list.num_prompt=5 \
#     task.prompt_list.prompt_manager_dict.prompt_manager_type="HumanPreferenceDataset_v2" \
#     task.prompt_list.prompt_manager_dict.cfg_yaml_path="./config/dataset/hpd_v2_20.yaml" \
#     task.task.num_sample_per_prompt=2 \
#     task.sample.height=512 \
#     task.sample.width=512 \
#     task.sample.num_inference_step=15 \
#     task.reward_model.reward_model_type="hps_v2" \
#     task.reward_model.cal_dynamics_batch_size=60 \
#     task.reward_model.cal_final_reward_batch_size=60 \
#     task.reward_model.reward_shaping_policy="latent_reward" \
#     task.mcts.mode.mdp_modeling="cumulative_reward" \
#     task.mcts.mode.value_policy="max" \
#     task.mcts.mode.pseudo_latent_as_final=False \
#     task.mcts.ucb.exploration_coef=${exploration_coef} \
#     task.mcts.selection.selection_depth_lim=None \
#     task.mcts.expansion.expansion_action_sampling_policy="beta" \
#     task.mcts.beta.update_policy="value_gradient" \
#     task.mcts.beta.value_gradient_update_time="back_propagation" \
#     task.mcts.beta.update_step_size=0.1 \
#     task.mcts.beta.max_update_bias=1.0 \
#     task.mcts.beta.zeta_list=10 \
#     task.mcts.nfe_limit.nfe_cal_dynamics_lim=${default_nfe} \
#     task.mcts.nfe_limit.nfe_cal_intermediate_reward_lim=1e9 \
#     task.mcts.nfe_limit.nfe_cal_final_reward_lim=1e9 \


# python main.py \
#     exp_name.default_exp_name="exp" \
#     vae_decode_batch_size=30 \
#     seed=${global_seed} \
#     pipeline=sd_v1_4 \
#     task=search/run_optimal_control_mcts/sd_v1_4/template \
#     task.init_latent.seed_list=0 \
#     task.eps.seed_list=3072 \
#     task.prompt_list.num_prompt=5 \
#     task.prompt_list.prompt_manager_dict.prompt_manager_type="HumanPreferenceDataset_v2" \
#     task.prompt_list.prompt_manager_dict.cfg_yaml_path="./config/dataset/hpd_v2_20.yaml" \
#     task.task.num_sample_per_prompt=2 \
#     task.sample.height=512 \
#     task.sample.width=512 \
#     task.sample.num_inference_step=15 \
#     task.reward_model.reward_model_type="hps_v2" \
#     task.reward_model.cal_dynamics_batch_size=60 \
#     task.reward_model.cal_final_reward_batch_size=60 \
#     task.reward_model.reward_shaping_policy="latent_reward" \
#     task.mcts.mode.mdp_modeling="max_reward" \
#     task.mcts.mode.value_policy="max" \
#     task.mcts.mode.pseudo_latent_as_final=True \
#     task.mcts.ucb.exploration_coef=${exploration_coef} \
#     task.mcts.selection.selection_depth_lim=None \
#     task.mcts.expansion.expansion_action_sampling_policy="beta" \
#     task.mcts.beta.update_policy="value_gradient" \
#     task.mcts.beta.value_gradient_update_time="best_trajectory_updated" \
#     task.mcts.beta.update_step_size=0.1 \
#     task.mcts.beta.max_update_bias=1.0 \
#     task.mcts.beta.zeta_list=10 \
#     task.mcts.nfe_limit.nfe_cal_dynamics_lim=${default_nfe} \
#     task.mcts.nfe_limit.nfe_cal_intermediate_reward_lim=1e9 \
#     task.mcts.nfe_limit.nfe_cal_final_reward_lim=1e9 \


# python main.py \
#     exp_name.default_exp_name="exp" \
#     vae_decode_batch_size=30 \
#     seed=${global_seed} \
#     pipeline=sd_v1_4 \
#     task=search/run_optimal_control_mcts/sd_v1_4/template \
#     task.init_latent.seed_list=0 \
#     task.eps.seed_list=3072 \
#     task.prompt_list.num_prompt=5 \
#     task.prompt_list.prompt_manager_dict.prompt_manager_type="HumanPreferenceDataset_v2" \
#     task.prompt_list.prompt_manager_dict.cfg_yaml_path="./config/dataset/hpd_v2_20.yaml" \
#     task.task.num_sample_per_prompt=2 \
#     task.sample.height=512 \
#     task.sample.width=512 \
#     task.sample.num_inference_step=15 \
#     task.reward_model.reward_model_type="hps_v2" \
#     task.reward_model.cal_dynamics_batch_size=60 \
#     task.reward_model.cal_final_reward_batch_size=60 \
#     task.reward_model.reward_shaping_policy="latent_reward" \
#     task.mcts.mode.mdp_modeling="max_reward" \
#     task.mcts.mode.value_policy="max" \
#     task.mcts.mode.pseudo_latent_as_final=True \
#     task.mcts.ucb.exploration_coef=${exploration_coef} \
#     task.mcts.selection.selection_depth_lim=8 \
#     task.mcts.expansion.expansion_action_sampling_policy="beta" \
#     task.mcts.beta.update_policy="value_gradient" \
#     task.mcts.beta.value_gradient_update_time="best_trajectory_updated" \
#     task.mcts.beta.update_step_size=0.1 \
#     task.mcts.beta.max_update_bias=1.0 \
#     task.mcts.beta.zeta_list=10 \
#     task.mcts.nfe_limit.nfe_cal_dynamics_lim=${default_nfe} \
#     task.mcts.nfe_limit.nfe_cal_intermediate_reward_lim=1e9 \
#     task.mcts.nfe_limit.nfe_cal_final_reward_lim=1e9 \


# latent, max, average, 5120
python main.py \
    exp_name.default_exp_name="max_average" \
    vae_decode_batch_size=10 \
    seed=${global_seed} \
    pipeline=sd_v1_4 \
    task=search/run_optimal_control_mcts/sd_v1_4/template \
    task.init_latent.seed_list=0 \
    task.eps.seed_list=5120 \
    task.prompt_list.num_prompt=20 \
    task.prompt_list.prompt_manager_dict.prompt_manager_type="HumanPreferenceDataset_v2" \
    task.prompt_list.prompt_manager_dict.cfg_yaml_path="./config/dataset/hpd_v2_30.yaml" \
    task.task.num_sample_per_prompt=2 \
    task.sample.height=512 \
    task.sample.width=512 \
    task.sample.num_inference_step=15 \
    task.reward_model.reward_model_type="hps_v2" \
    task.reward_model.cal_dynamics_batch_size=40 \
    task.reward_model.cal_final_reward_batch_size=40 \
    task.reward_model.reward_shaping_policy="latent_reward" \
    task.mcts.mode.mdp_modeling="max_reward" \
    task.mcts.mode.value_policy="average" \
    task.mcts.mode.pseudo_latent_as_final=True \
    task.mcts.ucb.exploration_coef=${exploration_coef} \
    task.mcts.selection.selection_depth_lim=None \
    task.mcts.expansion.expansion_action_sampling_policy="beta" \
    task.mcts.beta.update_policy="soft" \
    task.mcts.beta.update_step_size=0.1 \
    task.mcts.beta.max_update_bias=1.0 \
    task.mcts.beta.zeta_list=10 \
    task.mcts.nfe_limit.nfe_cal_dynamics_lim=${default_nfe} \
    task.mcts.nfe_limit.nfe_cal_intermediate_reward_lim=1e9 \
    task.mcts.nfe_limit.nfe_cal_final_reward_lim=1e9 \

