#!/usr/bin/env bash
# Generate Coefficient Sensitivity Experiments
set -euo pipefail

# ========================
# Configuration
# ========================
# Absolute directory of this generator; generated scripts are written here.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# Sibling baseline script path (not referenced anywhere else in this file).
BASE_SCRIPT="${SCRIPT_DIR}/ablation_kl_both_conditioned.sh"

# KL-in-Reward sweep; each entry is "hard_coef,nonhard_coef".
declare -a REWARD_COEF_COMBINATIONS=("0.3,1.0" "0.7,1.0" "0.5,0.8" "0.5,1.2")

# KL-in-Loss sweep; each entry is "loss_coef,hard_scale,nonhard_scale".
declare -a LOSS_COEF_COMBINATIONS=(
    "0.0005,0.3,1.0" "0.002,0.7,1.0"
    "0.001,0.5,0.8" "0.001,0.5,1.2"
)

# ========================
# Functions
# ========================
#######################################
# Write an executable launcher script for one KL-in-Reward coefficient
# sensitivity experiment. The generated script invokes
# verl.trainer.main_ppo with the given reward coefficients while the
# KL-in-Loss settings stay fixed at 0.001 / 0.5 / 1.0.
# Globals:
#   SCRIPT_DIR (read) - directory the generated script is written into
# Arguments:
#   $1 - value written to kl_in_reward_coef_hard
#   $2 - value written to kl_in_reward_coef_nonhard
# Outputs:
#   Prints "Generating: <script name>" to stdout and creates/overwrites
#   ${SCRIPT_DIR}/ablation_kl_reward_coef_sensitivity_<$1>_<$2>.sh
#######################################
generate_reward_sensitivity_script() {
    local hard_coef=$1
    local nonhard_coef=$2
    local script_name="ablation_kl_reward_coef_sensitivity_${hard_coef}_${nonhard_coef}.sh"
    local script_path="${SCRIPT_DIR}/${script_name}"

    echo "Generating: ${script_name}"

    # Unquoted heredoc: plain ${...} is expanded now (generation time), while
    # escaped \$... is emitted literally and evaluated when the script runs.
    cat > "${script_path}" << EOF
#!/usr/bin/env bash
# Ablation Experiment: KL-in-Reward Coefficient Sensitivity (${hard_coef}, ${nonhard_coef})
set -euo pipefail

# ========================
# Environment and system
# ========================
export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=\${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}

export NCCL_P2P_DISABLE=0
export NCCL_SHM_DISABLE=0

# Project and experiment configuration
project_name='AIME2024-Qwen3-8B'
timestamp=\$(date +%Y%m%d_%H%M%S)
experiment_name="aime2024_ablation_kl_reward_coef_${hard_coef}_${nonhard_coef}_\${timestamp}"

# Data paths
train_file="/primus_datasets/primus_data/clpo_SKYRTP/DAPO-Math-17k/data/dapo-math-17k.parquet"
val_file="/primus_datasets/primus_data/aime_2B4pCq/train-00000-of-00001-fixed.parquet"

# Model path
model_path="/primus_datasets/primus_data/Qwen3_rNrLUi/Qwen3-8B"

# Output directory
output_dir="/primus_oss/_checkpoint/0910-Qwen3-8B-AIME2024/ablation_kl_reward_coef_${hard_coef}_${nonhard_coef}"

# Data configuration
max_prompt_length=2048
max_response_length=8192
train_batch_size=64
val_batch_size=32
truncation="error"
filter_overlong_prompts=true
dataloader_num_workers=4

# Algorithm configuration - KL-in-Reward Coefficient Sensitivity
adv_estimator=grpo
use_kl_in_reward=true
kl_in_reward_coef_hard=${hard_coef}
kl_in_reward_coef_nonhard=${nonhard_coef}
use_kl_loss=true
kl_loss_coef=0.001
kl_loss_coef_hard_scale=0.5
kl_loss_coef_nonhard_scale=1.0

# Model configuration
enable_gradient_checkpointing=true
use_remove_padding=true

# Actor configuration
actor_lr=1e-6
actor_lr_warmup_steps=10
warmup_style=constant
ppo_mini_batch_size=8
ppo_micro_batch_size_per_gpu=1
entropy_coeff=0
param_offload=false
optimizer_offload=false

# Rollout configuration
rollout_name=vllm
n_resp_per_prompt=4
tensor_model_parallel_size=1
gpu_memory_utilization=0.5
log_prob_micro_batch_size_per_gpu=1
max_model_len=10240
max_num_batched_tokens=10240

# Trainer configuration
total_epochs=1
critic_warmup=0
test_freq=10
save_freq=50
val_before_train=true
ngpus_per_node=\${NGPUS_PER_NODE:-8}
nnodes=\${NNODES:-1}

# CLPO specific configuration
clpo_hard_acc_upper=0.3
clpo_med_acc_lower=0.3
clpo_med_acc_upper=0.7

# CLPO rewrite data saving configuration
clpo_save_rewrite_data=true
clpo_rewrite_save_path="\${output_dir}/rewrite_data.json"
clpo_hard_rewrite_save_path="\${output_dir}/hard_rewrite_data.json"
clpo_medium_rewrite_save_path="\${output_dir}/medium_rewrite_data.json"

echo "==== ABLATION EXPERIMENT: KL-in-Reward Coefficient Sensitivity (${hard_coef}, ${nonhard_coef}) ===="
echo "Project: \$project_name"
echo "Experiment: \$experiment_name"
echo "Train Data: \$train_file"
echo "Val Data: \$val_file"
echo "Model Path: \$model_path"
echo "Output Dir: \$output_dir"
echo "KL-in-Reward: \$use_kl_in_reward"
echo "KL-in-Reward Coef Hard: \$kl_in_reward_coef_hard"
echo "KL-in-Reward Coef Non-hard: \$kl_in_reward_coef_nonhard"
echo "KL-in-Loss: \$use_kl_loss"

# Record the trainer exit status via "||" so that "set -e" does not abort
# this script before the completion summary below is printed.
training_exit_code=0
python3 -m verl.trainer.main_ppo \\
  algorithm.adv_estimator="\${adv_estimator}" \\
  algorithm.norm_adv_by_std_in_grpo=true \\
  algorithm.use_kl_in_reward="\${use_kl_in_reward}" \\
  algorithm.kl_in_reward_coef_hard="\${kl_in_reward_coef_hard}" \\
  algorithm.kl_in_reward_coef_nonhard="\${kl_in_reward_coef_nonhard}" \\
  data.train_files="\${train_file}" \\
  data.val_files="\${val_file}" \\
  data.train_batch_size="\${train_batch_size}" \\
  data.val_batch_size="\${val_batch_size}" \\
  data.max_prompt_length="\${max_prompt_length}" \\
  data.max_response_length="\${max_response_length}" \\
  data.filter_overlong_prompts="\${filter_overlong_prompts}" \\
  data.truncation="\${truncation}" \\
  data.dataloader_num_workers="\${dataloader_num_workers}" \\
  data.clpo_hard_acc_upper="\${clpo_hard_acc_upper}" \\
  data.clpo_medium_acc_lower="\${clpo_med_acc_lower}" \\
  data.clpo_medium_acc_upper="\${clpo_med_acc_upper}" \\
  data.clpo_save_rewrite_data="\${clpo_save_rewrite_data}" \\
  data.clpo_rewrite_save_path="\${clpo_rewrite_save_path}" \\
  data.clpo_hard_rewrite_save_path="\${clpo_hard_rewrite_save_path}" \\
  data.clpo_medium_rewrite_save_path="\${clpo_medium_rewrite_save_path}" \\
  actor_rollout_ref.model.path="\${model_path}" \\
  actor_rollout_ref.model.enable_gradient_checkpointing="\${enable_gradient_checkpointing}" \\
  actor_rollout_ref.model.use_remove_padding="\${use_remove_padding}" \\
  actor_rollout_ref.actor.optim.lr="\${actor_lr}" \\
  actor_rollout_ref.actor.optim.lr_warmup_steps="\${actor_lr_warmup_steps}" \\
  actor_rollout_ref.actor.optim.warmup_style="\${warmup_style}" \\
  actor_rollout_ref.actor.ppo_mini_batch_size="\${ppo_mini_batch_size}" \\
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="\${ppo_micro_batch_size_per_gpu}" \\
  actor_rollout_ref.actor.use_kl_loss="\${use_kl_loss}" \\
  actor_rollout_ref.actor.kl_loss_coef="\${kl_loss_coef}" \\
  actor_rollout_ref.actor.kl_loss_coef_hard_scale="\${kl_loss_coef_hard_scale}" \\
  actor_rollout_ref.actor.kl_loss_coef_nonhard_scale="\${kl_loss_coef_nonhard_scale}" \\
  actor_rollout_ref.actor.entropy_coeff="\${entropy_coeff}" \\
  actor_rollout_ref.actor.fsdp_config.param_offload="\${param_offload}" \\
  actor_rollout_ref.actor.fsdp_config.optimizer_offload="\${optimizer_offload}" \\
  actor_rollout_ref.rollout.name="\${rollout_name}" \\
  actor_rollout_ref.rollout.n="\${n_resp_per_prompt}" \\
  actor_rollout_ref.rollout.tensor_model_parallel_size="\${tensor_model_parallel_size}" \\
  actor_rollout_ref.rollout.gpu_memory_utilization="\${gpu_memory_utilization}" \\
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="\${log_prob_micro_batch_size_per_gpu}" \\
  actor_rollout_ref.rollout.max_model_len="\${max_model_len}" \\
  actor_rollout_ref.rollout.max_num_batched_tokens="\${max_num_batched_tokens}" \\
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="\${log_prob_micro_batch_size_per_gpu}" \\
  actor_rollout_ref.ref.fsdp_config.param_offload="\${param_offload}" \\
  trainer.logger='["console", "swanlab"]' \\
  trainer.project_name="\${project_name}" \\
  trainer.experiment_name="\${experiment_name}" \\
  trainer.total_epochs="\${total_epochs}" \\
  trainer.critic_warmup="\${critic_warmup}" \\
  trainer.test_freq="\${test_freq}" \\
  trainer.save_freq="\${save_freq}" \\
  trainer.val_before_train="\${val_before_train}" \\
  trainer.n_gpus_per_node="\${ngpus_per_node}" \\
  trainer.nnodes="\${nnodes}" \\
  trainer.default_local_dir="\${output_dir}" \\
  trainer.task=clpo \\
  "\$@" || training_exit_code=\$?

echo ""
echo "🏁 ===== ABLATION EXPERIMENT COMPLETED ====="

if [ "\$training_exit_code" -eq 0 ]; then
    echo "✅ KL-in-Reward Coefficient Sensitivity (${hard_coef}, ${nonhard_coef}) experiment completed successfully!"
    echo "📁 Checkpoints saved to: \${output_dir}"
    echo "📊 Experiment name: \${experiment_name}"
else
    echo "❌ Experiment failed with exit code: \$training_exit_code"
fi

exit \$training_exit_code
EOF

    chmod +x "${script_path}"
}

#######################################
# Write an executable launcher script for one KL-in-Loss coefficient
# sensitivity experiment. The generated script invokes
# verl.trainer.main_ppo with the given loss coefficient and scales while
# the KL-in-Reward coefficients stay fixed at 0.5 / 1.0.
# Globals:
#   SCRIPT_DIR (read) - directory the generated script is written into
# Arguments:
#   $1 - value written to kl_loss_coef
#   $2 - value written to kl_loss_coef_hard_scale
#   $3 - value written to kl_loss_coef_nonhard_scale
# Outputs:
#   Prints "Generating: <script name>" to stdout and creates/overwrites
#   ${SCRIPT_DIR}/ablation_kl_loss_coef_sensitivity_<$1>_<$2>_<$3>.sh
#######################################
generate_loss_sensitivity_script() {
    local loss_coef=$1
    local hard_scale=$2
    local nonhard_scale=$3
    local script_name="ablation_kl_loss_coef_sensitivity_${loss_coef}_${hard_scale}_${nonhard_scale}.sh"
    local script_path="${SCRIPT_DIR}/${script_name}"

    echo "Generating: ${script_name}"

    # Unquoted heredoc: plain ${...} is expanded now (generation time), while
    # escaped \$... is emitted literally and evaluated when the script runs.
    cat > "${script_path}" << EOF
#!/usr/bin/env bash
# Ablation Experiment: KL-in-Loss Coefficient Sensitivity (${loss_coef}, ${hard_scale}, ${nonhard_scale})
set -euo pipefail

# ========================
# Environment and system
# ========================
export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=\${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}

export NCCL_P2P_DISABLE=0
export NCCL_SHM_DISABLE=0

# Project and experiment configuration
project_name='AIME2024-Qwen3-8B'
timestamp=\$(date +%Y%m%d_%H%M%S)
experiment_name="aime2024_ablation_kl_loss_coef_${loss_coef}_${hard_scale}_${nonhard_scale}_\${timestamp}"

# Data paths
train_file="/primus_datasets/primus_data/clpo_SKYRTP/DAPO-Math-17k/data/dapo-math-17k.parquet"
val_file="/primus_datasets/primus_data/aime_2B4pCq/train-00000-of-00001-fixed.parquet"

# Model path
model_path="/primus_datasets/primus_data/Qwen3_rNrLUi/Qwen3-8B"

# Output directory
output_dir="/primus_oss/_checkpoint/0910-Qwen3-8B-AIME2024/ablation_kl_loss_coef_${loss_coef}_${hard_scale}_${nonhard_scale}"

# Data configuration
max_prompt_length=2048
max_response_length=8192
train_batch_size=64
val_batch_size=32
truncation="error"
filter_overlong_prompts=true
dataloader_num_workers=4

# Algorithm configuration - KL-in-Loss Coefficient Sensitivity
adv_estimator=grpo
use_kl_in_reward=true
kl_in_reward_coef_hard=0.5
kl_in_reward_coef_nonhard=1.0
use_kl_loss=true
kl_loss_coef=${loss_coef}
kl_loss_coef_hard_scale=${hard_scale}
kl_loss_coef_nonhard_scale=${nonhard_scale}

# Model configuration
enable_gradient_checkpointing=true
use_remove_padding=true

# Actor configuration
actor_lr=1e-6
actor_lr_warmup_steps=10
warmup_style=constant
ppo_mini_batch_size=8
ppo_micro_batch_size_per_gpu=1
entropy_coeff=0
param_offload=false
optimizer_offload=false

# Rollout configuration
rollout_name=vllm
n_resp_per_prompt=4
tensor_model_parallel_size=1
gpu_memory_utilization=0.5
log_prob_micro_batch_size_per_gpu=1
max_model_len=10240
max_num_batched_tokens=10240

# Trainer configuration
total_epochs=1
critic_warmup=0
test_freq=10
save_freq=50
val_before_train=true
ngpus_per_node=\${NGPUS_PER_NODE:-8}
nnodes=\${NNODES:-1}

# CLPO specific configuration
clpo_hard_acc_upper=0.3
clpo_med_acc_lower=0.3
clpo_med_acc_upper=0.7

# CLPO rewrite data saving configuration
clpo_save_rewrite_data=true
clpo_rewrite_save_path="\${output_dir}/rewrite_data.json"
clpo_hard_rewrite_save_path="\${output_dir}/hard_rewrite_data.json"
clpo_medium_rewrite_save_path="\${output_dir}/medium_rewrite_data.json"

echo "==== ABLATION EXPERIMENT: KL-in-Loss Coefficient Sensitivity (${loss_coef}, ${hard_scale}, ${nonhard_scale}) ===="
echo "Project: \$project_name"
echo "Experiment: \$experiment_name"
echo "Train Data: \$train_file"
echo "Val Data: \$val_file"
echo "Model Path: \$model_path"
echo "Output Dir: \$output_dir"
echo "KL-in-Reward: \$use_kl_in_reward"
echo "KL-in-Loss: \$use_kl_loss"
echo "KL Loss Coef: \$kl_loss_coef"
echo "KL Loss Coef Hard Scale: \$kl_loss_coef_hard_scale"
echo "KL Loss Coef Non-hard Scale: \$kl_loss_coef_nonhard_scale"

# Record the trainer exit status via "||" so that "set -e" does not abort
# this script before the completion summary below is printed.
training_exit_code=0
python3 -m verl.trainer.main_ppo \\
  algorithm.adv_estimator="\${adv_estimator}" \\
  algorithm.norm_adv_by_std_in_grpo=true \\
  algorithm.use_kl_in_reward="\${use_kl_in_reward}" \\
  algorithm.kl_in_reward_coef_hard="\${kl_in_reward_coef_hard}" \\
  algorithm.kl_in_reward_coef_nonhard="\${kl_in_reward_coef_nonhard}" \\
  data.train_files="\${train_file}" \\
  data.val_files="\${val_file}" \\
  data.train_batch_size="\${train_batch_size}" \\
  data.val_batch_size="\${val_batch_size}" \\
  data.max_prompt_length="\${max_prompt_length}" \\
  data.max_response_length="\${max_response_length}" \\
  data.filter_overlong_prompts="\${filter_overlong_prompts}" \\
  data.truncation="\${truncation}" \\
  data.dataloader_num_workers="\${dataloader_num_workers}" \\
  data.clpo_hard_acc_upper="\${clpo_hard_acc_upper}" \\
  data.clpo_medium_acc_lower="\${clpo_med_acc_lower}" \\
  data.clpo_medium_acc_upper="\${clpo_med_acc_upper}" \\
  data.clpo_save_rewrite_data="\${clpo_save_rewrite_data}" \\
  data.clpo_rewrite_save_path="\${clpo_rewrite_save_path}" \\
  data.clpo_hard_rewrite_save_path="\${clpo_hard_rewrite_save_path}" \\
  data.clpo_medium_rewrite_save_path="\${clpo_medium_rewrite_save_path}" \\
  actor_rollout_ref.model.path="\${model_path}" \\
  actor_rollout_ref.model.enable_gradient_checkpointing="\${enable_gradient_checkpointing}" \\
  actor_rollout_ref.model.use_remove_padding="\${use_remove_padding}" \\
  actor_rollout_ref.actor.optim.lr="\${actor_lr}" \\
  actor_rollout_ref.actor.optim.lr_warmup_steps="\${actor_lr_warmup_steps}" \\
  actor_rollout_ref.actor.optim.warmup_style="\${warmup_style}" \\
  actor_rollout_ref.actor.ppo_mini_batch_size="\${ppo_mini_batch_size}" \\
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="\${ppo_micro_batch_size_per_gpu}" \\
  actor_rollout_ref.actor.use_kl_loss="\${use_kl_loss}" \\
  actor_rollout_ref.actor.kl_loss_coef="\${kl_loss_coef}" \\
  actor_rollout_ref.actor.kl_loss_coef_hard_scale="\${kl_loss_coef_hard_scale}" \\
  actor_rollout_ref.actor.kl_loss_coef_nonhard_scale="\${kl_loss_coef_nonhard_scale}" \\
  actor_rollout_ref.actor.entropy_coeff="\${entropy_coeff}" \\
  actor_rollout_ref.actor.fsdp_config.param_offload="\${param_offload}" \\
  actor_rollout_ref.actor.fsdp_config.optimizer_offload="\${optimizer_offload}" \\
  actor_rollout_ref.rollout.name="\${rollout_name}" \\
  actor_rollout_ref.rollout.n="\${n_resp_per_prompt}" \\
  actor_rollout_ref.rollout.tensor_model_parallel_size="\${tensor_model_parallel_size}" \\
  actor_rollout_ref.rollout.gpu_memory_utilization="\${gpu_memory_utilization}" \\
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="\${log_prob_micro_batch_size_per_gpu}" \\
  actor_rollout_ref.rollout.max_model_len="\${max_model_len}" \\
  actor_rollout_ref.rollout.max_num_batched_tokens="\${max_num_batched_tokens}" \\
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="\${log_prob_micro_batch_size_per_gpu}" \\
  actor_rollout_ref.ref.fsdp_config.param_offload="\${param_offload}" \\
  trainer.logger='["console", "swanlab"]' \\
  trainer.project_name="\${project_name}" \\
  trainer.experiment_name="\${experiment_name}" \\
  trainer.total_epochs="\${total_epochs}" \\
  trainer.critic_warmup="\${critic_warmup}" \\
  trainer.test_freq="\${test_freq}" \\
  trainer.save_freq="\${save_freq}" \\
  trainer.val_before_train="\${val_before_train}" \\
  trainer.n_gpus_per_node="\${ngpus_per_node}" \\
  trainer.nnodes="\${nnodes}" \\
  trainer.default_local_dir="\${output_dir}" \\
  trainer.task=clpo \\
  "\$@" || training_exit_code=\$?

echo ""
echo "🏁 ===== ABLATION EXPERIMENT COMPLETED ====="

if [ "\$training_exit_code" -eq 0 ]; then
    echo "✅ KL-in-Loss Coefficient Sensitivity (${loss_coef}, ${hard_scale}, ${nonhard_scale}) experiment completed successfully!"
    echo "📁 Checkpoints saved to: \${output_dir}"
    echo "📊 Experiment name: \${experiment_name}"
else
    echo "❌ Experiment failed with exit code: \$training_exit_code"
fi

exit \$training_exit_code
EOF

    chmod +x "${script_path}"
}

# ========================
# Main Execution
# ========================
printf '%s\n' "===== GENERATING COEFFICIENT SENSITIVITY EXPERIMENTS ====="

# KL-in-Reward sweep: split each "hard,nonhard" entry and emit one launcher.
printf '%s\n' "Generating KL-in-Reward coefficient sensitivity experiments..."
for reward_spec in "${REWARD_COEF_COMBINATIONS[@]}"; do
    IFS=',' read -r reward_hard reward_nonhard <<< "${reward_spec}"
    generate_reward_sensitivity_script "${reward_hard}" "${reward_nonhard}"
done

# KL-in-Loss sweep: split each "coef,hard_scale,nonhard_scale" entry likewise.
printf '%s\n' "Generating KL-in-Loss coefficient sensitivity experiments..."
for loss_spec in "${LOSS_COEF_COMBINATIONS[@]}"; do
    IFS=',' read -r kl_coef kl_hard_scale kl_nonhard_scale <<< "${loss_spec}"
    generate_loss_sensitivity_script "${kl_coef}" "${kl_hard_scale}" "${kl_nonhard_scale}"
done

# Summarise what was written and point at the batch runner.
printf '%s\n' "===== GENERATION COMPLETED =====" "Generated scripts:"
ls -la "${SCRIPT_DIR}"/ablation_kl_*_coef_sensitivity_*.sh

printf '\n'
printf '%s\n' \
    "To run all coefficient sensitivity experiments:" \
    "bash ${SCRIPT_DIR}/run_all_coefficient_sensitivity_experiments.sh"
