---
# Experiment spec for a multi-node open-instruct fine-tuning test run.
version: v2
description: open-instruct-finetune-multinode-test
tasks:
  - name: open-instruct-finetune-multinode-test
    # Keep in step with `--num_machines` in `arguments` (both are 4).
    replicas: 4
    # NOTE(review): `arguments` reads $BEAKER_LEADER_REPLICA_HOSTNAME and
    # $BEAKER_REPLICA_RANK; presumably these two switches are what make the
    # leader reachable from the other replicas -- confirm against Beaker docs.
    leaderSelection: true
    hostNetworking: true
    image:
      beaker: Yizhongw03/open-instruct-multi-node
    # The single string in `arguments` is executed as `/bin/sh -c <string>`.
    command:
      - '/bin/sh'
      - '-c'
    # Shell command line handed to `/bin/sh -c` as one string.
    #
    # Written as a folded block scalar (`>-`): the lines below are joined with
    # single spaces and the final newline is stripped, so the shell sees one
    # long command with no leading or trailing whitespace. Every line of the
    # scalar must keep exactly the same indentation; a more-indented line would
    # be kept on its own line instead of being folded.
    #
    # Values that must stay consistent with the rest of this task:
    #   --num_machines 4   matches `replicas: 4`
    #   --num_processes 32 equals 4 machines x `resources.gpuCount` (8)
    #   --output_dir       matches `result.path` (/output)
    arguments:
      - >-
        unset CUDA_LAUNCH_BLOCKING && accelerate launch
        --mixed_precision bf16
        --num_machines 4
        --num_processes 32
        --machine_rank $BEAKER_REPLICA_RANK
        --main_process_ip $BEAKER_LEADER_REPLICA_HOSTNAME
        --main_process_port 29400
        --use_deepspeed
        --deepspeed_config_file ds_configs/stage3_no_offloading_accelerate.conf
        --deepspeed_multinode_launcher standard
        open_instruct/finetune.py
        --model_name_or_path /net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/70B
        --tokenizer_name /net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/70B
        --use_slow_tokenizer
        --train_file /net/nfs.cirrascale/allennlp/yizhongw/open-instruct-public/data/processed/sharegpt/sharegpt_data.jsonl
        --use_flash_attn
        --max_seq_length 1024
        --preprocessing_num_workers 64
        --per_device_train_batch_size 1
        --gradient_accumulation_steps 4
        --learning_rate 2e-5
        --lr_scheduler_type linear
        --warmup_ratio 0.03
        --weight_decay 0.
        --num_train_epochs 5
        --output_dir /output/
        --with_tracking
        --report_to tensorboard
        --logging_steps 1
    # Environment variables for every replica. Values are strings; anything
    # that looks like a YAML boolean is quoted so the literal lowercase text
    # (`false` / `true`) reaches the process regardless of how the spec loader
    # handles native booleans.
    envVars:
      - name: CUDA_DEVICE_ORDER
        value: PCI_BUS_ID
      - name: TRANSFORMERS_CACHE
        value: ./cache/
      - name: WANDB_PROJECT
        value: open-instruct
      - name: WANDB_WATCH
        value: 'false'
      - name: WANDB_LOG_MODEL
        value: 'false'
      # W&B is switched off here; `arguments` reports to tensorboard instead.
      - name: WANDB_DISABLED
        value: 'true'
      - name: NCCL_NET
        value: IB
      - name: NCCL_DEBUG
        value: INFO
    # Mount the host's NFS tree at the same path inside the container; the
    # model, tokenizer and training-file paths in `arguments` live under it.
    datasets:
      - mountPath: /net/nfs.cirrascale
        source:
          hostPath: /net/nfs.cirrascale
    # Result directory; `arguments` writes to it via `--output_dir /output/`.
    result:
      path: /output
    # GPUs per replica. With `replicas: 4` this gives the 32 processes passed
    # as `--num_processes` in `arguments`.
    resources:
      gpuCount: 8
    context:
      priority: high
    constraints:
      cluster:
        - ai2/general-cirrascale-a100-80g-ib