# Run definition: job name, container image, compute request, code checkout
# and environment for the training command defined under `command`.
name: olmo-70b
image: mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04

# 256 GPUs in total; `command` starts 8 processes per node via torchrun.
compute:
  cluster: r9z3
  gpu_type: h100_80gb
  gpus: 256

# Source checkout: the `train` extra of the repository is installed in
# editable mode from the given branch.
integrations:
  - integration_type: git_repo
    git_repo: allenai/OLMo
    git_branch: mitchish65-2
    ssh_clone: true
    pip_install: '-e .[train]'

# Values are quoted so they reach the process environment as strings.
env_variables:
  # Silences pip's "new version available" check during installs.
  PIP_DISABLE_PIP_VERSION_CHECK: '1'
  # Number of OpenMP threads per process.
  OMP_NUM_THREADS: '8'
  # NOTE(review): presumably limits log output to local rank 0 on each node;
  # confirm against the training code that reads this variable.
  LOG_FILTER_TYPE: local_rank0_only
command: |-
  # Abort on the first failing command. Without this, a failed flash-attn
  # install or a missing OLMo checkout would still fall through to torchrun
  # and start training on every node in a broken state.
  set -e

  # Install AWS CLI (for downloading unsharded checkpoints).
  # Currently disabled; uncomment the lines below if the run needs it.
  #apt-get update
  #apt-get install zip unzip
  #curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
  #unzip awscliv2.zip
  #sudo ./aws/install

  # Make sure we have a recent flash-attn.
  # NOTE: the version is pinned only to future-proof this run against
  # later releases.
  pip install flash-attn==2.5.3 --no-build-isolation

  # Show packages for debugging.
  pip freeze

  # Prepare environment: enter the checkout and create the torch cache
  # directory.
  cd OLMo
  mkdir -p /root/.cache/torch

  # Launch training with 8 processes per node.
  # MASTER_ADDR, MASTER_PORT, NUM_NODES and NODE_RANK are read from the
  # environment (presumably injected by the platform -- TODO confirm).
  # Batch sizing: 768 global / 256 GPUs = 3 per process, which equals
  # --device_train_microbatch_size. Keep these consistent if `gpus` changes.
  torchrun \
  --master_addr "$MASTER_ADDR" \
  --master_port "$MASTER_PORT" \
  --nnodes "$NUM_NODES" \
  --node_rank "$NODE_RANK" \
  --nproc_per_node 8 \
  scripts/train.py configs/mitchish70-s3.yaml \
    --run_name=mitchish70-001 \
    --wandb.group=mitchish70 \
    --global_train_batch_size=768 \
    --device_train_microbatch_size=3 \
    --save_overwrite
