#!/bin/bash

# SwanLab SFT Training Script
# Usage: ./train_swanlab.sh [gpu_id] [project_name] [mode] [api_key]

# ---------------------------------------------------------------------------
# Command-line parameters (positional; each falls back to a default when the
# argument is omitted or empty)
# ---------------------------------------------------------------------------
GPU_ID="${1:-0}"                     # exported below as CUDA_VISIBLE_DEVICES
PROJECT_NAME="${2:-sft-reasoning}"   # SwanLab project name
SWANLAB_MODE="${3:-cloud}"           # cloud, offline, disabled, or local
API_KEY="${4:-}"                     # optional SwanLab API key

# ---------------------------------------------------------------------------
# Paths (edit before running)
# ---------------------------------------------------------------------------
MODEL_PATH='path/to/your/model'          # placeholder: point this at the base model
DATA_FILE='../data/train-en.jsonl'       # training data (train-en.jsonl/train-el.jsonl)
VAL_DATA_FILE=''                         # validation data; empty by default
OUTPUT_DIR='output-dir'                  # created with mkdir -p, passed as --output_dir
LOG_FILE='./rectraining.log'             # file named in the "tail -f" hints printed at the end

# ---------------------------------------------------------------------------
# Optimizer settings forwarded to sft.py
# ---------------------------------------------------------------------------
USE_RECALL_ADAM='false'   # true/false, passed as --use_recall_adam
# NOTE(review): 2e-8 is far below the values suggested here (2e-4/1e-4/5e-5);
# confirm it is intentional and not a typo.
LEARNING_RATE='2e-8'      # passed as --learning_rate
REC_PRETRAIN_COF='3000'   # passed as --rec_pretrain_cof


# Print the effective configuration before launching, so the settings used
# for this run are visible in the launching terminal.
# The REC_PRETRAIN_COF line is labelled "Rec Pretrain Cof" so the label
# matches the variable it reports.
printf '%s\n' \
    "Starting SFT training with SwanLab monitoring..." \
    "GPU: $GPU_ID" \
    "Project: $PROJECT_NAME" \
    "SwanLab Mode: $SWANLAB_MODE" \
    "Data: $DATA_FILE" \
    "Output: $OUTPUT_DIR" \
    "Use Recall Adam:$USE_RECALL_ADAM" \
    "Rec Pretrain Cof:$REC_PRETRAIN_COF" \
    "Learning rate:$LEARNING_RATE" \
    "Log: $LOG_FILE"

# Create output directory (no-op when it already exists).
# Quoted and preceded by "--" so a path containing spaces, glob characters or
# a leading dash is passed to mkdir as exactly one operand.
mkdir -p -- "$OUTPUT_DIR"

# Set GPU environment: restrict CUDA to the requested device id(s)
export CUDA_VISIBLE_DEVICES="$GPU_ID"

# Fix for RTX 4000 series GPU communication issues
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"

# Resolve the SwanLab API key.
#   1. A key supplied via API_KEY wins and is exported as SWANLAB_API_KEY.
#   2. Otherwise an already-set SWANLAB_API_KEY is left untouched.
#   3. Otherwise nothing is exported and a notice is printed.
if [[ -n "${API_KEY}" ]]; then
    SWANLAB_API_KEY="${API_KEY}"
    export SWANLAB_API_KEY
    printf '%s\n' "SwanLab API key set from parameter"
elif [[ -n "${SWANLAB_API_KEY}" ]]; then
    printf '%s\n' "SwanLab API key found in environment"
else
    printf '%s\n' "No SwanLab API key provided - using default settings"
fi

# Build SwanLab arguments.
# Kept as one flat, space-separated option string that is word-split where it
# is handed to sft.py, so PROJECT_NAME and SWANLAB_MODE must not contain
# whitespace.
SWANLAB_ARGS="--use_swanlab --swanlab_project $PROJECT_NAME --swanlab_mode $SWANLAB_MODE"

# Add LoRA flag
LORA_ARGS="--use_lora"

# Generate experiment name with timestamp: sft_<data file stem>_<YYYYmmdd_HHMMSS>.
# DATA_FILE is quoted (and preceded by "--") so a path containing spaces or a
# leading dash still reaches basename as a single operand.
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
EXPERIMENT_NAME="sft_$(basename -- "$DATA_FILE" .jsonl)_$TIMESTAMP"

echo "Experiment name: $EXPERIMENT_NAME"
echo "Starting training..."

# Launch sft.py in the background under nohup, with stdout and stderr written
# to $LOG_FILE, and record the PID of the launched job in TRAIN_PID.
#
# The command line is assembled in an array so that:
#   * every value is passed to sft.py as exactly one argument, and
#   * the output redirection and the trailing "&" are part of the same command
#     as the launch, so $! refers to the training process itself.

# LORA_ARGS and SWANLAB_ARGS are flat option strings; split them on whitespace
# into arrays (read -a splits on IFS without performing glob expansion).
read -r -a lora_args <<< "$LORA_ARGS"
read -r -a swanlab_args <<< "$SWANLAB_ARGS"

train_cmd=(
    python sft.py
    --data_path "$DATA_FILE"
)

# Pass --val_data_path only when a validation file is configured. With an
# empty value the flag would be left without an argument and would consume
# the option that follows it.
# NOTE(review): assumes sft.py treats --val_data_path as optional - confirm.
if [[ -n "$VAL_DATA_FILE" ]]; then
    train_cmd+=(--val_data_path "$VAL_DATA_FILE")
fi

train_cmd+=(
    --model_path "$MODEL_PATH"
    --max_seq_length 2048
    "${lora_args[@]}"
    "${swanlab_args[@]}"
    --swanlab_experiment_name "$EXPERIMENT_NAME"
    --swanlab_description "SFT training on $(basename -- "$DATA_FILE") with LoRA"
    --output_dir "$OUTPUT_DIR"
    --per_device_train_batch_size 1
    --gradient_accumulation_steps 8
    --num_train_epochs 5
    --learning_rate "$LEARNING_RATE"
    --warmup_ratio 0.1
    --lr_scheduler_type "cosine"
    --logging_steps 10
    --save_total_limit 5
    --bf16 True
    --max_grad_norm 1.0
    --dataloader_num_workers 4
    --remove_unused_columns False
    --seed 42
    --use_recall_adam "$USE_RECALL_ADAM"
    --rec_anneal_fun sigmoid
    --rec_anneal_k 0.2
    --rec_anneal_t0 100
    --rec_anneal_w 1.0
    --rec_pretrain_cof "$REC_PRETRAIN_COF"
)

nohup "${train_cmd[@]}" > "$LOG_FILE" 2>&1 &

# PID of the background job started on the line above
TRAIN_PID=$!

# Post-launch summary: report the recorded PID, how to follow the log and stop
# the job, and a short reference of the SwanLab modes. Emitted as a single
# here-document; only ${TRAIN_PID} and ${LOG_FILE} are expanded.
cat <<EOF
✅ Training started!
   PID: ${TRAIN_PID}
   Monitor: tail -f ${LOG_FILE}
   Stop: kill ${TRAIN_PID}

💡 Tips:
   - Training logs: tail -f ${LOG_FILE}
   - GPU usage: nvidia-smi
   - Stop training: kill ${TRAIN_PID}

🔧 SwanLab Modes:
   - cloud: Real-time sync to cloud (requires API key)
   - offline: Local logging only, no cloud sync
   - local: Local SwanLab server
   - disabled: No SwanLab tracking
EOF
