# Run specification: container image, hardware and code checkout for the job whose shell
# script is given under `command`.
# NOTE(review): the field names look like a MosaicML (mcli) run config -- confirm against
# the launcher's documentation.

# Name of the job itself. The trainer's --run_name is set separately by the `run_name`
# shell variable inside `command`.
run_name: olmo-7b-final
image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
# The script starts 8 processes per node (--nproc_per_node 8), so this presumably
# corresponds to 8 nodes -- TODO confirm.
gpu_num: 64
cluster: r12z3
# Alternate cluster, left commented out for quick switching:
#cluster: r7z2
gpu_type: a100_40gb
integrations:
  # Check out the `main` branch of allenai/LLM over SSH and pip-install it in editable
  # mode. The script later does `cd LLM` to work inside this checkout.
  - integration_type: git_repo
    git_repo: allenai/LLM
    git_branch: main
    pip_install: -e .
    ssh_clone: true
command: |-
  # NOTE: Getting S3 and R2 authentication working both from the command line and from Python
  # proved to be challenging, maybe because Mosaic's servers are in Australia.
  # In the end, separate methods were needed to get everything working:
  #  1. AWS config files for CLI access.
  #  2. Environment variables for boto3 access (to S3 only).
  # Since CLI access is only needed prior to training, the AWS config files are removed before
  # launching the training job. Otherwise the environment variables won't work.

  # Adjust these vars as needed.
  #   checkpoint: remote URL of the unsharded checkpoint directory to download and resume from.
  #   run_name:   passed to the trainer as --run_name and used in the remote save folder path.
  #   config:     training config file, given relative to the LLM checkout.
  # Alternative checkpoint / run name pair, kept commented out:
  #checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/svtto91c/step456000-unsharded
  #run_name=mitchish-lumi-2T-final
  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/ho7jy4ey/step432410-unsharded
  run_name=mitchish-mcli-2T-final
  config=configs/v1_5-mix-medium-mitch-ish-s3.yaml

  # Install the AWS CLI v2 from the official zip bundle.
  apt-get update
  # -y: the job has no interactive terminal, so apt must never wait on (and then abort at)
  # its "Do you want to continue?" confirmation prompt.
  apt-get install -y zip unzip
  curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
  unzip awscliv2.zip
  sudo ./aws/install

  # Work from inside the repository checkout.
  cd LLM

  # Record the installed Python packages in the job log.
  pip freeze

  # Scratch directories, plus the AWS CLI config directory used for both S3 and R2 access.
  mkdir -p /root/.cache/torch
  mkdir /root/checkpoint-unsharded /root/data /root/.aws

  # CLI credentials: one named profile per storage backend, filled in from the environment.
  cat >> /root/.aws/credentials <<EOF
  [s3]
  aws_access_key_id = ${AWS_ACCESS_KEY_ID}
  aws_secret_access_key = ${AWS_SECRET_ACCESS_KEY}

  [r2]
  aws_access_key_id = ${R2_ACCESS_KEY_ID}
  aws_secret_access_key = ${R2_SECRET_ACCESS_KEY}
  EOF

  # Default CLI settings.
  cat >> /root/.aws/config <<'EOF'
  [default]
  region = auto
  output = json
  EOF

  # Profile selection via environment variables is currently disabled:
  #export S3_PROFILE=s3
  #export R2_PROFILE=r2
  export OMP_NUM_THREADS=8 LOG_FILTER_TYPE=local_rank0_only

  # Download the unsharded checkpoint into /root/checkpoint-unsharded using the 'r2' CLI
  # profile and the R2 endpoint. Files, in order: config, trainer state, model weights,
  # optimizer state.
  echo "Downloading checkpoint '${checkpoint}'..."
  r2_endpoint=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com
  for checkpoint_file in config.yaml train.pt model.pt optim.pt; do
    # Abort the job if any file cannot be fetched, rather than launching training against
    # an incomplete checkpoint directory.
    if ! aws s3 cp --profile=r2 --region=auto \
      --endpoint-url="${r2_endpoint}" \
      "${checkpoint}/${checkpoint_file}" /root/checkpoint-unsharded/; then
      echo "Failed to download '${checkpoint}/${checkpoint_file}'" >&2
      exit 1
    fi
  done

  # Now remove the aws configs so they don't interfere with data loading / uploading
  # checkpoints to/from S3 (which rely on environment variables instead).
  rm -rf /root/.aws

  # Launch distributed training with 8 processes on this node, resuming from the downloaded
  # checkpoint. Rendezvous settings (MASTER_ADDR, MASTER_PORT, NUM_NODES, NODE_RANK) are read
  # from the environment.
  # Everything after the config path is a command-line override passed to scripts/train.py.
  # scheduler.t_warmup matches the step number in the checkpoint name (432410), and
  # scheduler.t_max = 434633 = 432410 + 2223.
  # Values containing '[' ... ']' or variable expansions are quoted so the shell passes them
  # through verbatim instead of treating them as glob patterns or word-splitting them.
  torchrun \
    --master_addr "$MASTER_ADDR" \
    --master_port "$MASTER_PORT" \
    --nnodes "$NUM_NODES" \
    --node_rank "$NODE_RANK" \
    --nproc_per_node 8 \
    scripts/train.py "${config}" \
    --run_name="${run_name}" \
    --save_overwrite \
    --save_interval_unsharded=10000 \
    --load_path=/root/checkpoint-unsharded \
    --compile=null \
    --model.flash_attention=true \
    --activation_checkpointing=fine_grained \
    --fsdp.wrapping_strategy=size_based \
    --remote_save_folder="s3://ai2-llm/checkpoints/7b/${run_name}" \
    --restore_dataloader=false \
    --eval_interval=100 \
    '--data.paths=[s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample-9B/gpt-neox-20b-pii-special/data.npy,s3://ai2-llm/preprocessed/tulu-v2-sft-mixture/gpt-neox-20b-pii-special/data.npy]' \
    --optimizer.learning_rate=0.000023 \
    --scheduler.t_warmup=432410 \
    --scheduler.alpha_f=0.001 \
    --scheduler.t_max=434633  # 432410 + 2223
