#!/bin/bash
#
# SLURM job script: supervised fine-tuning (full parameters) of
# Llama-2-7b-chat with LLaMA-Factory, launched through `accelerate` on the
# A100 GPUs requested below, with metrics reported to Weights & Biases.
#
# sbatch only honours lines that begin exactly with "#SBATCH", and it stops
# reading directives at the first executable command, so every directive must
# stay above the commands further down.

#SBATCH --job-name=llama_full
#SBATCH --output=llama_full.out
#SBATCH --error=llama_full.err
# Disabled names for the *_scrp run. "# SBATCH" (with a space) is not parsed
# as a directive; swap these with the three lines above to use them.
# SBATCH --job-name=llama_full_scrp
# SBATCH --output=llama_full_scrp.out
# SBATCH --error=llama_full_scrp.err

#SBATCH --partition=compute
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
# accelerate starts one training process per GPU inside this single task, so
# the task needs several cores: with 1 core, all GPU ranks, their dataloaders
# and dataset tokenization contend for a single CPU.
# NOTE(review): 4 cores per GPU is a conservative guess; size to the node.
#SBATCH --cpus-per-task=12
#SBATCH --gres=gpu:A100:3
#SBATCH --time=23:00:00
#SBATCH --mail-type=ALL
#SBATCH --mail-user=bo@andrew.cmu.edu

# --- Environment -------------------------------------------------------------
# `conda activate` depends on conda's shell hook, which ~/.bashrc normally
# installs. Fail fast if activation does not work (for example, a ~/.bashrc
# that returns early for non-interactive shells such as batch jobs) instead of
# spending the GPU allocation on a run with the wrong Python.
source ~/.bashrc
if ! conda activate agent; then
  echo "error: could not activate conda environment 'agent'" >&2
  exit 1
fi

# Strict mode is enabled only from here on, because ~/.bashrc and conda's
# activation scripts are not generally safe under `set -u`.
set -euo pipefail

# --- Launch settings ---------------------------------------------------------
# One training process per GPU. Prefer an explicit NUM_GPUS, then the count
# SLURM exports for the --gres request, then 3 (the count in the header), so
# the process count cannot silently drift from the allocation.
num_gpus=${NUM_GPUS:-${SLURM_GPUS_ON_NODE:-3}}

# Rendezvous port for accelerate's main process. A fixed port makes two jobs
# placed on the same node fail with "address already in use", so derive it
# from the job ID. Outside SLURM this evaluates to the previous value, 20504.
# MAIN_PROCESS_PORT overrides it.
main_port=${MAIN_PROCESS_PORT:-$((20000 + ${SLURM_JOB_ID:-504} % 10000))}

readonly TRAIN_SCRIPT=/data/b_ou/agent-model/LLaMA-Factory/src/train_bash.py

# Options for LLaMA-Factory's train_bash.py, one per line.
train_args=(
  --stage sft
  --model_name_or_path /data/b_ou/ckpts/llama2/models--meta-llama--Llama-2-7b-chat-hf/snapshots/c1b0db933684edbfe29a06fa47eb19cc48025e93
  --cache_path /data/b_ou/ckpts/data_cache/llama2
  --do_train
  --do_eval
  --dataset m2w_text
  --train_size 16500
  --shuffle False
  --dataset_dir /data/b_ou/agent/data/text/
  --template llama2
  --finetuning_type full
  --output_dir /data/b_ou/ckpts/output_16k_llama2_full/
  --overwrite_output_dir True
  --overwrite_cache
  # Effective batch size = 1 per device x 16 accumulation x num_gpus
  # (48 with 3 GPUs); it changes if the GPU count changes.
  --per_device_train_batch_size 1
  --per_device_eval_batch_size 1
  --gradient_accumulation_steps 16
  --gradient_checkpointing True
  --lr_scheduler_type cosine
  # No --eval_steps is given, so the HF Trainer falls back to logging_steps.
  --evaluation_strategy steps
  --save_strategy epoch
  --logging_steps 30
  --save_total_limit 1
  --learning_rate 5e-5
  --num_train_epochs 5
  --plot_loss
  --bf16 True
  --cutoff_len 4096
  --fsdp "full_shard auto_wrap"
  --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
  --report_to wandb
)

# --- Launch ------------------------------------------------------------------
WANDB__SERVICE_WAIT=300 WANDB_PROJECT=llama accelerate launch \
  --main_process_port "$main_port" \
  --num_processes "$num_gpus" \
  --num_machines 1 \
  "$TRAIN_SCRIPT" \
  "${train_args[@]}"

        # --eval_steps 1 \




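# Disabled alternative run, kept for reference: LoRA (q_proj,v_proj) SFT on the
# m2w_text_scrape dataset with 8 processes. To use it, uncomment it and adjust
# the SBATCH header (job names, GPU count) to match.
# WANDB__SERVICE_WAIT=300 WANDB_PROJECT=llama accelerate launch  --main_process_port 20504 --num_processes 8 --num_machines 1  /data/b_ou/agent-model/LLaMA-Factory/src/train_bash.py \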
#     --stage sft \
#     --model_name_or_path /data/b_ou/ckpts/llama2/models--meta-llama--Llama-2-7b-chat-hf/snapshots/c1b0db933684edbfe29a06fa47eb19cc48025e93 \
#     --do_train \
#     --do_eval \
#     --dataset m2w_text_scrape \
#     --train_size 4 \
#     --shuffle False \
#     --dataset_dir /data/b_ou/agent/data/text_scrp_short/ \
#     --template llama2 \
#     --output_dir /data/b_ou/ckpts/output_8k_llama_full_scrp/ \
#     --overwrite_output_dir True \
#     --overwrite_cache \
#     --per_device_train_batch_size 1 \
#     --per_device_eval_batch_size 1 \
#     --gradient_accumulation_steps 16 \
#     --gradient_checkpointing True \
#     --lr_scheduler_type cosine \
#     --evaluation_strategy "steps" \
#     --save_strategy "epoch" \
#     --logging_steps "2" \
#     --eval_accumulation_steps 1 \
#     --save_total_limit 1 \
#     --learning_rate 5e-5 \
#     --num_train_epochs 5 \
#     --plot_loss \
#     --bf16 True \
#     --cutoff_len 4096 \
#     --predict_with_generate True \
#     --finetuning_type lora \
#     --lora_target q_proj,v_proj \
#     --report_to 'wandb'


#         --model_max_length 4096 \

#     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \

#     --fsdp "full_shard auto_wrap" \
#     --finetuning_type full \
#     --cache_path /data/b_ou/ckpts/data_cache/llama2_scrp \
