#!/bin/bash
# Settings consumed by the `swift rlhf` launch command later in this script.

# PyTorch CUDA allocator option; exported so child processes inherit it.
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'
export PYTORCH_CUDA_ALLOC_CONF

# Placeholder paths -- replace both before running.
MODEL='/path/to/your/hybrid_model'            # passed to --model
OUTPUT_DIR='/path/to/your/output_dir'         # passed to --output_dir

# Distributed training configuration.
nnodes=1            # forwarded to the launcher as NNODES
nproc_per_node=8    # forwarded to the launcher as NPROC_PER_NODE
# NOTE(review): presumably the per-device batch size -- verify that the launch
# command actually reads this variable.
BATCHSIZE=1

# Launch training job
#
# The CLI flags are collected in an array so that each expansion stays a single
# argument (paths with spaces survive) and so the flags can be commented.
# Flag order matches the original invocation.
swift_args=(
  # Task, models and data.
  --rlhf_type gkd
  --teacher_model Qwen/Qwen2.5-VL-7B-Instruct
  --model "$MODEL"
  --train_type full
  --dataset lmms-lab/LLaVA-Video-178K
  --split_dataset_ratio 0
  --attn_impl flash_attn
  --dataset_num_proc 1

  # Optimizer and learning-rate schedule.
  --adam_beta1 0.9
  --adam_beta2 0.95
  --adam_epsilon 1e-8
  --weight_decay 0.1
  --lr_scheduler_type 'cosine'
  --learning_rate 1e-6
  --max_length 8192
  --num_train_epochs 1
  --warmup_ratio 0.03

  # Batching. The train batch size follows BATCHSIZE (previously hard-coded to
  # 1, which silently ignored the variable); falls back to 1 when unset.
  --per_device_train_batch_size "${BATCHSIZE:-1}"
  --per_device_eval_batch_size 1
  --gradient_accumulation_steps 8
  --eval_steps 1000

  # Memory / parallelism strategy.
  --deepspeed zero3_offload
  --gradient_checkpointing true

  # Checkpointing and logging.
  --save_steps 1000
  --save_only_model false
  --save_total_limit 10
  --logging_steps 1

  # GKD loss knobs.
  # NOTE(review): in TRL's GKD config, lmbda is the fraction of on-policy
  # (student-generated) data and beta interpolates the generalized JSD --
  # confirm swift uses the same meaning.
  --lmbda 0
  --beta 1
  --sft_alpha 0

  --output_dir "$OUTPUT_DIR"
  --torch_dtype bfloat16

  # NOTE(review): presumably freezes every parameter and then re-enables only
  # those under model.language_model -- confirm against the swift docs.
  --freeze_parameters_ratio 1
  --trainable_parameters model.language_model
)

# The environment assignments below apply only to this one `swift` invocation.
# CUDA_VISIBLE_DEVICES lists eight devices; keep it in step with nproc_per_node
# when changing the process count.
# NOTE(review): MAX_PIXELS / VIDEO_MAX_PIXELS / FPS_MAX_FRAMES are presumably
# read by swift's vision preprocessing -- confirm against its docs.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NNODES="$nnodes" \
NPROC_PER_NODE="$nproc_per_node" \
MAX_PIXELS=200704 \
VIDEO_MAX_PIXELS=50176 \
FPS_MAX_FRAMES=64 \
swift rlhf "${swift_args[@]}"