# Run specification for the "olmo-70b" training job: container image,
# scheduling, compute allocation, source checkouts, and environment.
name: 'olmo-70b'
image: 'mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04'
# Alternative image locations (inactive):
# image: public.ecr.aws/z0f8p3z5/olmo:pytorch2.2.1_cu121-python3.11-ubuntu20.04
# image: us-central1-docker.pkg.dev/ai2-olmo/olmo/pytorch:2.2.1_cu121-python3.11-ubuntu20.04
scheduling:
  priority: 'auto'
  # preemptible: true  # means it can be retried
  # max_retries: 10
compute:
  cluster: 'r15z4'
  # Total GPU count requested for the run. The launch command in this file
  # starts 8 processes per node (--nproc_per_node 8).
  gpus: 896
  gpu_type: 'h100_80gb'
  instance: 'oci.bm.gpu.h100.8'
  # node_names:
integrations:
  # Two git checkouts, each installed in editable mode via pip. The launch
  # command later changes into the "OLMo" directory.
  - integration_type: 'git_repo'
    git_repo: 'allenai/OLMo'
    git_branch: 'train-olmo-large'
    pip_install: '-e .[train]'
    ssh_clone: true
  - integration_type: 'git_repo'
    git_repo: 'allenai/OLMo-core'
    git_branch: 'WorksTorch22'
    pip_install: '-e .'
    ssh_clone: true
env_variables:
  # Values are quoted so that numeric-looking settings stay strings.
  PIP_DISABLE_PIP_VERSION_CHECK: '1'
  OMP_NUM_THREADS: '8'
  # NOTE(review): presumably limits log output to local rank 0 - confirm
  # against the consumer of this variable.
  LOG_FILTER_TYPE: 'local_rank0_only'
command: |-
  # Setup-and-launch script: installs extra Python packages, warms the
  # Hugging Face cache, then starts distributed training with torchrun.
  #
  # Fail fast: abort on the first failing command or failing pipeline stage
  # so that torchrun is never launched on a half-prepared environment.
  # (This script already requires bash because it uses pushd/popd.)
  set -eo pipefail

  # Make sure we have a recent flash-attn.
  # NOTE: only pinning flash-attn here to future proof it.
  pip install flash-attn==2.5.3 --no-build-isolation
  # Install AWS CLI (for pre-downloading unsharded checkpoints).
  pip install awscli

  # Show packages for debugging.
  pip freeze

  # Prepare environment.
  mkdir -p /root/.cache/torch
  # Warm up the Hugging Face cache by unpacking a prebuilt archive into
  # /root/.cache. --fail makes curl exit non-zero on an HTTP error instead of
  # piping the error body into tar.
  pushd /root/.cache
  curl --fail "https://storage.googleapis.com/dirkgr-public/huggingface_cache_v3.tar.gz" | tar -xzf -
  popd
  # Datasets must come from the warmed cache from here on.
  export HF_DATASETS_OFFLINE=1

  # Optional pre-download of an unsharded checkpoint (currently disabled).
  #checkpoint=s3://ai2-llm/checkpoints/OLMo-large/mitchish70-002/step160500-unsharded-hacked
  #mkdir /root/checkpoint-unsharded
  #aws s3 cp --no-progress ${checkpoint}/config.yaml /root/checkpoint-unsharded/
  #aws s3 cp --no-progress ${checkpoint}/train.pt /root/checkpoint-unsharded/
  #aws s3 cp --no-progress ${checkpoint}/model.safetensors /root/checkpoint-unsharded/
  #aws s3 cp --no-progress ${checkpoint}/optim.safetensors /root/checkpoint-unsharded/

  cd OLMo

  # MASTER_ADDR, MASTER_PORT, NUM_NODES and NODE_RANK are not set in this
  # file; they are expected to come from the launch environment.
  # The single-quoted arguments keep ${...} literal so the shell does not
  # expand them; presumably the training config loader resolves them - verify.
  echo "Launching train script..."
  torchrun \
  --master_addr "$MASTER_ADDR" \
  --master_port "$MASTER_PORT" \
  --nnodes "$NUM_NODES" \
  --node_rank "$NODE_RANK" \
  --nproc_per_node 8 \
  scripts/train.py configs/mitchish70-s3.yaml \
    --run_name=mitchish70-pland \
    '--wandb.group=${run_name}' \
    '--load_path=${path.last_checkpoint:${remote_save_folder}}' \
    --load_path_sharded_checkpointer=olmo_core \
    --sharded_checkpointer=olmo_core \
    --optimizer.learning_rate=0.000075 \
    --global_train_batch_size=3584 \
    --device_train_microbatch_size=4 \
    --fsdp.sharding_strategy=HYBRID_SHARD \
    --fsdp.hybrid_sharding_num_model_replicas=4 \
    --time_limit=604800 \
    --save_overwrite

# Inactive alternatives for the launch command above, kept for reference.
#
# --fsdp.sharding_strategy=HYBRID_SHARD \
# --fsdp.hybrid_sharding_num_model_replicas=4 \
#
#    '--load_path=${path.last_checkpoint:${remote_save_folder}}' \
#    --load_path=s3://ai2-llm/checkpoints/OLMo-large/mitchish70-planc/step197000 \
#    --load_path=s3://ai2-llm/checkpoints/OLMo-large/mitchish70-002/step48950 \
#    --load_path=s3://ai2-llm/checkpoints/OLMo-large/mitchish70-002/step49000 \
#    --load_path=/root/checkpoint-unsharded \
# Batch-size settings noted for different GPU counts:
#  gpus: 256
#    --global_train_batch_size=1536 \
#  gpus: 384
#    --global_train_batch_size=1536 \
#    --device_train_microbatch_size=2 \
#  gpus: 896
#    --global_train_batch_size=1792 \
#  gpus: 600  # (75 nodes)
#    --global_train_batch_size=1800 \
