# Run identity and the container image the job executes in.
name: "olmo-7b"
image: "mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04"

# Hardware request: 64 H100 (80 GB) GPUs on cluster r15z4.
# NOTE(review): the instance name suggests 8 GPUs per node, which would match
# `--nproc_per_node 8` in the command below (i.e. 8 nodes) - confirm.
compute:
  cluster: "r15z4"
  gpus: 64
  gpu_type: "h100_80gb"
  instance: "oci.bm.gpu.h100.8"
# Git checkouts made available to the run; each one is pip-installed in
# editable mode using the arguments given in `pip_install`.
integrations:
  # Training code. The command below does `cd OLMo`, so it relies on this
  # checkout being present in the working directory.
  - integration_type: "git_repo"
    git_repo: "allenai/OLMo"
    git_branch: "train-olmo-large"
    pip_install: "-e .[train]"
    ssh_clone: true
  # OLMo-core, tracked on its main branch.
  - integration_type: "git_repo"
    git_repo: "allenai/OLMo-core"
    git_branch: "main"
    pip_install: "-e ."
    ssh_clone: true
# Environment variables exported to the run. All values are quoted so they
# stay strings regardless of the YAML parser's implicit typing.
env_variables:
  # Silences pip's "new version available" check during installs.
  PIP_DISABLE_PIP_VERSION_CHECK: "1"
  # Caps OpenMP worker threads per process.
  OMP_NUM_THREADS: "8"
  # NOTE(review): presumably read by the OLMo training script to limit log
  # output to local rank 0 - confirm against the training code.
  LOG_FILTER_TYPE: "local_rank0_only"
command: |-
  # Fail fast: stop on the first failing command, on any unset variable, and
  # on a failure anywhere in a pipeline, instead of launching the training job
  # after a broken setup step. (pushd/popd below already require bash, so the
  # bash-only `pipefail` option is safe here.)
  set -euo pipefail

  # Install AWS CLI (for downloading unsharded checkpoints).
  # Currently disabled; uncomment if the run needs the `aws` command.
  #apt-get update
  #apt-get install zip unzip
  #curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
  #unzip awscliv2.zip
  #sudo ./aws/install

  # Make sure we have a recent flash-attn.
  # NOTE: only pinning flash-attn here to future proof it.
  pip install flash-attn==2.5.3 --no-build-isolation

  # Show packages for debugging.
  pip freeze

  # Prepare environment.
  mkdir -p /root/.cache/torch
  # Warm up the Hugging Face cache by unpacking a prebuilt archive into
  # /root/.cache. `--fail` makes curl exit non-zero on an HTTP error so that an
  # error page is never piped into tar; with pipefail this aborts the run.
  pushd /root/.cache
  curl --fail "https://storage.googleapis.com/dirkgr-public/huggingface_cache_v3.tar.gz" | tar -xzf -
  popd
  # The cache unpacked above must be complete: datasets is put in offline mode.
  export HF_DATASETS_OFFLINE=1

  cd OLMo

  # Launch training with 8 processes on this node; rendezvous address, port,
  # node count and node rank come from the environment.
  # The --load_path argument is single-quoted on purpose so that the shell does
  # not expand the ${...} expression and it reaches scripts/train.py verbatim.
  # NOTE(review): `remote_save_folder` is presumably defined in
  # configs/mitchish7-s3.yaml - confirm.
  torchrun \
  --master_addr "$MASTER_ADDR" \
  --master_port "$MASTER_PORT" \
  --nnodes "$NUM_NODES" \
  --node_rank "$NODE_RANK" \
  --nproc_per_node 8 \
  scripts/train.py configs/mitchish7-s3.yaml \
    --run_name=mitchish7 \
    --wandb.group=mitchish7 \
    --model.flash_attention=true \
    --fsdp.wrapping_strategy=by_block_and_size \
    --fsdp.sharding_strategy=SHARD_GRAD_OP \
    --save_folder=runs/ \
    --activation_checkpointing=fine_grained \
    --fused_loss=true \
    --device_train_microbatch_size=2 \
    --global_train_batch_size=1024 \
    --gen1_gc_interval=32 \
    '--load_path=${path.last_checkpoint:${remote_save_folder}}' \
    --save_overwrite
