#!/bin/bash


# ---------------------------------------------------------------------------
# Configuration.
#
# ROOT_DIR, NUM_ROUNDS and INIT_MODEL may be supplied through the environment;
# when they are unset or empty the defaults below are used, e.g.:
#   ROOT_DIR=/data/work NUM_ROUNDS=3 bash this_script.sh
# ---------------------------------------------------------------------------

# Directory that contains the Merge-of-Thought checkout and the
# llamafactory_config directory. Every other path below is derived from it.
ROOT_DIR="${ROOT_DIR:-}"

# Number of train-then-merge rounds to run.
NUM_ROUNDS="${NUM_ROUNDS:-5}"

# One entry per client; each is trained on the dataset "merge_sft_<entry>".
CLIENT_DATASETS=("r1_200" "32b_200" "235b_200" "qwq_200")

# Checkpoint used as the starting model of the first round.
INIT_MODEL="${INIT_MODEL:-/../checkpoints/base/Qwen3-8B}"

# Training entry point launched with torchrun.
TRAIN_SCRIPT="${ROOT_DIR}/Merge-of-Thought/src/train.py"

# DeepSpeed configuration file passed to the training script.
# NOTE(review): the path repeats "llamafactory_config" twice - confirm that
# this matches the on-disk layout.
DEEPSPEED_CONFIG="${ROOT_DIR}/llamafactory_config/llamafactory_config/deepspeed/ds_z3_offload.json"

# Per-round client checkpoints and merged models are written below this dir.
BASE_OUTPUT_DIR="${ROOT_DIR}/Merge-of-Thought/saves/boba_64"

# Location where the Python merge helper is generated.
MERGE_SCRIPT="${ROOT_DIR}/Merge-of-Thought/merge_models.py"


# Generate the Python helper that averages the weights of several checkpoints.
# The heredoc delimiter is quoted, so the shell performs no expansion inside
# the Python source. Abort immediately if the file cannot be written: every
# round depends on this script.
cat > "$MERGE_SCRIPT" << 'EOF' || { echo "failed to write merge script: $MERGE_SCRIPT" >&2; exit 1; }
import torch
from transformers import AutoModelForCausalLM, AutoConfig
import os
import argparse
import shutil

# Auxiliary (non-weight) files copied verbatim from the first input directory
# when they exist. config.json / generation_config.json may afterwards be
# rewritten by save_pretrained().
AUX_FILES = [
    'config.json', 'generation_config.json',
    'tokenizer_config.json', 'vocab.json', 'merges.txt',
    'tokenizer.json', 'special_tokens_map.json',
    'added_tokens.json', 'chat_template.jinja',
]


def load_model(model_dir):
    """Load a causal-LM checkpoint in float16 without automatic device placement."""
    # NOTE(review): the training command uses --bf16 while checkpoints are
    # loaded (and therefore saved) as float16 here; kept for compatibility
    # with the existing output format - confirm this is intended.
    return AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
        device_map=None,
    )


def merge_models(model_dirs, output_dir):
    """Average the parameters of several checkpoints and save the result.

    Args:
        model_dirs: non-empty list of checkpoint directories. The first one
            supplies the architecture, the auxiliary files and any
            non-floating-point tensors.
        output_dir: directory the merged model is written to (created if
            it does not exist).

    Raises:
        ValueError: if model_dirs is empty or a tensor shape differs
            between checkpoints.
        KeyError: if a checkpoint lacks a parameter present in the first one.
    """
    if not model_dirs:
        raise ValueError("at least one model directory is required")

    print(f"⚙️ begin merge: {model_dirs} -> {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    for file in AUX_FILES:
        src_path = os.path.join(model_dirs[0], file)
        if os.path.exists(src_path):
            shutil.copy(src_path, output_dir)
            print(f"  copy: {file}")

    print("loading...")
    base_model = load_model(model_dirs[0])

    # Accumulate in separate float32 tensors: this avoids half-precision
    # rounding while summing, and because every name owns its accumulator,
    # tensors that share storage in the model are not added twice.
    # Cost: roughly 4 extra bytes of host memory per parameter.
    # Non-floating-point tensors are not averaged; the base values are kept.
    sums = {
        name: tensor.detach().to(torch.float32, copy=True)
        for name, tensor in base_model.state_dict().items()
        if tensor.is_floating_point()
    }

    for peer_dir in model_dirs[1:]:
        print(f"weight from: {peer_dir}")
        peer_model = load_model(peer_dir)
        peer_state_dict = peer_model.state_dict()

        # A silently skipped tensor would still be divided by the full model
        # count below and yield a wrong average, so fail loudly instead.
        missing = [name for name in sums if name not in peer_state_dict]
        if missing:
            raise KeyError(
                f"{peer_dir} lacks {len(missing)} tensor(s), e.g. {missing[0]}"
            )

        for name, total in sums.items():
            peer_tensor = peer_state_dict[name]
            if peer_tensor.shape != total.shape:
                raise ValueError(
                    f"shape mismatch for {name} in {peer_dir}: "
                    f"{tuple(peer_tensor.shape)} vs {tuple(total.shape)}"
                )
            total += peer_tensor.to(torch.float32)

        del peer_model, peer_state_dict
        torch.cuda.empty_cache()

    num_models = len(model_dirs)
    for total in sums.values():
        total /= num_models

    print(f"save: {output_dir}")
    # load_state_dict copies the averages into the model's own tensors
    # (converting back to their dtype); strict=False because
    # non-floating-point entries are intentionally absent from `sums`.
    base_model.load_state_dict(sums, strict=False)
    base_model.save_pretrained(output_dir)

    print("✅ over")
    del base_model, sums
    torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='merge models from multiple directories')
    parser.add_argument("--input_dirs", nargs="+", required=True,
                        help="required list of directories to merge")
    parser.add_argument("--output_dir", required=True,
                        help="required output directory to save merged model")
    args = parser.parse_args()
    merge_models(args.input_dirs, args.output_dir)
EOF


# ---------------------------------------------------------------------------
# Iterative train-then-merge loop.
#
# Each round fine-tunes the current global model once per entry of
# CLIENT_DATASETS (sequentially, each run using 8 GPUs), then averages the
# resulting checkpoints with MERGE_SCRIPT. The merged model becomes the
# starting point of the next round. The script aborts as soon as a training
# run or a merge fails, so later rounds never start from a missing model.
# ---------------------------------------------------------------------------
current_round_model="$INIT_MODEL"
num_clients=${#CLIENT_DATASETS[@]}

for ((round = 1; round <= NUM_ROUNDS; round++)); do
    echo -e "\n\033[1;34m===== merge round $round/$NUM_ROUNDS =====\033[0m"

    # Output directories of this round's clients; they are the merge inputs.
    client_dirs=()
    for i in "${!CLIENT_DATASETS[@]}"; do
        client="${CLIENT_DATASETS[$i]}"
        dataset="merge_sft_${client}"
        output_dir="${BASE_OUTPUT_DIR}/round_${round}/${client}"
        client_dirs+=("$output_dir")
        mkdir -p "$output_dir" || {
            echo "cannot create output directory: $output_dir" >&2
            exit 1
        }

        echo -e "\033[1;32mclient $((i+1))/${num_clients}: $dataset (8gpu)\033[0m"

        # Arguments are kept in an array so that no option can be dropped by
        # a missing line-continuation backslash.
        train_args=(
            --model_name_or_path "$current_round_model"
            --stage sft
            --do_train
            --finetuning_type full
            --deepspeed "$DEEPSPEED_CONFIG"
            --dataset_dir "${ROOT_DIR}/Merge-of-Thought/data"
            --dataset "$dataset"
            --template qwen3
            --cutoff_len 25000
            --overwrite_cache
            --preprocessing_num_workers 16
            --output_dir "$output_dir"
            --logging_steps 1
            --save_steps 50
            --save_total_limit 10
            --save_only_model
            --per_device_train_batch_size 1
            --gradient_accumulation_steps 8
            --learning_rate 1.0e-5
            --max_steps 50
            --lr_scheduler_type cosine
            --warmup_ratio 0.01
            --weight_decay 0.1
            --adam_beta1 0.9
            --adam_beta2 0.95
            --max_grad_norm 1.0
            --bf16
            --flash_attn fa2
            --report_to none
            --ddp_timeout 180000000
        )

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
        torchrun --nproc_per_node=8 "$TRAIN_SCRIPT" "${train_args[@]}" \
            2>&1 | tee "$output_dir/train.log"
        # The pipeline's own status is tee's; read torchrun's status from
        # PIPESTATUS immediately, before any other command overwrites it.
        train_status=${PIPESTATUS[0]}
        if (( train_status != 0 )); then
            echo "training failed for $dataset in round $round (exit status $train_status)" >&2
            exit "$train_status"
        fi

        echo -e "\033[1;32m✅ branch $((i+1))/${num_clients} finish\033[0m"
    done

    echo -e "\n\033[1;32m✅ all branch finish ${round} training\033[0m"

    # Merge the model parameters of all clients.
    merged_model="${BASE_OUTPUT_DIR}/merged_round_${round}"
    echo -e "\033[1;35m⚙️ begin merge -> ${merged_model}\033[0m"
    if ! python "$MERGE_SCRIPT" \
        --input_dirs "${client_dirs[@]}" \
        --output_dir "$merged_model"; then
        echo "merge failed in round $round" >&2
        exit 1
    fi

    # The merged model is the starting point of the next round.
    current_round_model="$merged_model"
    echo -e "\n\033[1;35m=== round ${round} finish! global model updated: ${current_round_model} ===\033[0m"

    # Best-effort cleanup between rounds.
    # NOTE(review): this kills every process that nvidia-smi lists with
    # "python" in its line, not only processes started by this script -
    # confirm this is acceptable on shared machines.
    echo "clean gpu memory..."
    nvidia-smi | grep python | awk '{print $5}' | xargs -I{} kill {} > /dev/null 2>&1
    echo "clean cache..."
    sudo sync; sudo sysctl -w vm.drop_caches=3 > /dev/null
done

echo -e "\n\033[1;32m✨ training finish! model save: ${current_round_model}\033[0m"
echo -e "test model:"
echo -e "  python -c \"from transformers import AutoModel; AutoModel.from_pretrained('${current_round_model}')\""