#!/usr/bin/env bash
# Step 1: launch finetuning.py through torchrun.
#
# Abort on the first failing command (-e) and on use of an unset variable (-u),
# so nothing further in this script runs after a failed training job.
# Only `set -eu` is used (no pipefail): the script has no pipelines, and this
# keeps it runnable under a plain POSIX sh as well as bash.
set -eu

# CUDA_VISIBLE_DEVICES=1,2 exposes only GPU indices 1 and 2 to the job;
# torchrun starts 2 processes on a single node (--nnodes 1 --nproc_per_node 2).
# All remaining flags are forwarded verbatim to finetuning.py.
# NOTE(review): --prune_ckpt / --lora_ckpt point into a sibling ../LLM-Pruner
# checkout; these relative paths presumably assume the script is started from
# the repository root -- confirm.
# The last argument line intentionally has NO trailing backslash, so the
# command ends here regardless of what follows it in the file.
CUDA_VISIBLE_DEVICES=1,2 torchrun --nnodes 1 --nproc_per_node 2 finetuning.py \
  --batch_size_training 2 --lr 5e-5 \
  --num_epochs 10 \
  --dataset aoa_dataset \
  --enable_fsdp \
  --model_name ckpts/decapoda-research-llama-7B-hf-prune/ --pure_bf16 \
  --tune_prune_LLM True \
  --prune_ckpt ../LLM-Pruner/prune_log/llama_prune/decapoda-research-llama-7B-hf/pytorch_model.bin \
  --lora_ckpt ../LLM-Pruner/tune_log/llama_2_0.2

# Step 2: run the checkpoint converter script on the FSDP output directory.
# Going by the script and flag names, this presumably consolidates the FSDP
# checkpoint into a Hugging Face-style model directory -- confirm against
# inference/checkpoint_converter_fsdp_hf.py.
# NOTE(review): the "epoch=10" suffix in the output path looks like it mirrors
# the --num_epochs value used for training; keep the two in sync if so.
python inference/checkpoint_converter_fsdp_hf.py \
  -fsdp_checkpoint_path fsdp/fine-tuned-ckpts/decapoda-research-llama-7B-hf-prune/ \
  -consolidated_model_path finetuned_models/7b-prune/aoa-epoch=10 \
  -HF_model_path_or_name ckpts/decapoda-research-llama-7B-hf-prune/
