# Run configuration: launches the instruction-tuning job defined in `command` below.
# `name` identifies the run to the launcher; it is separate from the `run_name`
# that the command passes to scripts/train.py.
name: olmo-7b-instruct
# Container image the job runs in. Per its tag: PyTorch 2.1.0, CUDA 12.1,
# Python 3.10, Ubuntu 20.04.
image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
# Hardware request for the run.
compute:
  # Alternative cluster, left commented out — presumably kept for quick switching.
  #cluster: r12z3
  cluster: r7z2
  # NOTE(review): the launch command starts 8 processes per node
  # (`--nproc_per_node 8`), so 64 GPUs presumably means 8 nodes — confirm.
  gpus: 64
  gpu_type: a100_40gb
# Source checkout made available to the job before `command` runs.
integrations:
  # Training code repository. The command does `cd LLM`, so it relies on the
  # checkout directory being named after the repository.
  - integration_type: git_repo
    git_repo: allenai/LLM
    git_branch: main
    # Arguments for pip: editable install of the checked-out repository.
    pip_install: -e .
    # Presumably clones over SSH instead of HTTPS (which would need an SSH key
    # configured for the job) — confirm against the launcher's documentation.
    ssh_clone: true
command: |-
  # Launch script for the instruction-tuning run. Steps:
  #  1. Install the AWS CLI.
  #  2. Write temporary AWS config files and download the unsharded checkpoint from R2.
  #  3. Remove the AWS config files and start training with torchrun.

  # Abort on the first failing command, so training is never launched without the AWS CLI
  # or with a missing / partially downloaded checkpoint.
  set -e

  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
  checkpoint_dir=/root/checkpoint-unsharded
  r2_endpoint_url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com
  learning_rate=2e-6
  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-5ep-v2

  # Copy one file of the checkpoint from R2 into the local checkpoint directory.
  #   $1: file name relative to ${checkpoint} (e.g. model.pt)
  download_from_r2() {
    aws s3 cp --profile=r2 --region=auto \
      --endpoint-url="${r2_endpoint_url}" \
      "${checkpoint}/$1" "${checkpoint_dir}/"
  }

  # NOTE: For some reason getting S3 and R2 authentication working both from the command line and
  # from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
  # In the end I had to use separate methods to get everything working:
  #  1. AWS config files for CLI access.
  #  2. Environment variables for boto3 access (to S3 only).
  # Since we only need CLI access prior to training, we remove the AWS config files before launching
  # the training job. Otherwise the environment variables won't work.

  # Install aws cli
  apt-get update
  # -y: answer apt's confirmation prompt automatically so the install cannot stall or abort on it.
  apt-get install -y zip unzip
  # --fail: exit non-zero on an HTTP error instead of saving the error page as the zip file.
  curl --fail "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
  unzip awscliv2.zip
  sudo ./aws/install

  cd LLM

  # Log the installed package versions.
  pip freeze

  # Prepare environment including AWS config files for both S3 and R2 access.
  # -p: do not fail if a directory already exists.
  mkdir -p /root/.cache/torch "${checkpoint_dir}" /root/.aws

  # Two named profiles: 's3' uses the AWS keys, 'r2' uses the R2 keys. The keys are read from the
  # environment when this script runs.
  cat > /root/.aws/credentials <<EOF
  [s3]
  aws_access_key_id = ${AWS_ACCESS_KEY_ID}
  aws_secret_access_key = ${AWS_SECRET_ACCESS_KEY}

  [r2]
  aws_access_key_id = ${R2_ACCESS_KEY_ID}
  aws_secret_access_key = ${R2_SECRET_ACCESS_KEY}
  EOF

  cat > /root/.aws/config <<EOF
  [default]
  region = auto
  output = json
  EOF

  #export S3_PROFILE=s3
  #export R2_PROFILE=r2
  export OMP_NUM_THREADS=8
  export LOG_FILTER_TYPE=local_rank0_only

  # Download checkpoint (everything except optimizer state).
  echo "Downloading checkpoint '${checkpoint}'..."

  # config.yaml: config, train.pt: trainer state, model.pt: model weights.
  for checkpoint_file in config.yaml train.pt model.pt; do
    download_from_r2 "${checkpoint_file}"
  done

  # Download optimizer state.
  #download_from_r2 optim.pt

  # Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
  rm -rf /root/.aws

  # MASTER_ADDR, MASTER_PORT, NUM_NODES and NODE_RANK are read from the environment (not set in this
  # script). Everything after the config path is an override passed to scripts/train.py.
  torchrun \
    --master_addr "$MASTER_ADDR" \
    --master_port "$MASTER_PORT" \
    --nnodes "$NUM_NODES" \
    --node_rank "$NODE_RANK" \
    --nproc_per_node 8 \
    scripts/train.py configs/mitchish-instruct.yaml \
      --run_name="${run_name}" \
      --optimizer.learning_rate="${learning_rate}" \
      --scheduler.grad_clip_warmup_steps=400 \
      --save_overwrite \
      --save_interval_unsharded=100000 \
      --load_path="${checkpoint_dir}" \
      --reset_trainer_state \
      --reset_optimizer_state \
      --compile=null \
      --model.flash_attention=true \
      --activation_checkpointing=whole_layer \
      --fsdp.wrapping_strategy=size_based \
      --max_duration=5ep
