# Free-text description attached to this submission.
description: 'listwise_loss_weight'

# target:
#   service: aml
#   name: alta2

# environment:
#   image: nvcr:v23.10
#   registry: shumingdocker.azurecr.io
#   setup:
#   - echo "master_addr:" "$$MASTER_ADDR"
#   - echo "master_port:" $$MASTER_PORT
#   - echo "node_rank:" $$OMPI_COMM_WORLD_RANK
#   username: shumingdocker

# Active compute target. Other target definitions in this file are
# commented out, so this is the only one in effect.
target:
  service: 'sing'
  name: 'msroctovc'
  resource_group: 'gcr-singularity-octo'
  # Alternative workspace name (disabled):
  # workspace_name: Workspace_NLC
  workspace_name: 'NLC_Workspace'

# target:
#   service: sing
#   name: msrresrchvc
#   resource_group: gcr-singularity-resrch
#   workspace_name: Workspace_NLC


# target:
#   service: sing
#   name: msroctobasicvc
#   resource_group: gcr-singularity-octo
#   workspace_name: Workspace_NLC


# Container image and setup commands for the job environment.
environment:
  # Alternative images (disabled):
  # image: hangbo/pytorch-2.23dev:xformers
  # image: hangbo/pytorch-2.23dev:xformers_s2
  image: 'amlt-sing/acpt-2.3.1-py3.10-cuda12.1'
  # Each entry echoes one of the distributed-launch variables.
  # The commands are single-quoted so that the embedded double quotes and
  # colons are unambiguously part of the command string.
  # NOTE(review): "$$" is presumably the escape for a literal "$" that the
  # shell then expands — confirm against the submission tool's docs.
  setup:
    - 'echo "master_addr:" "$$MASTER_ADDR"'
    - 'echo "master_port:" $$MASTER_PORT'
    - 'echo "node_rank:" $$OMPI_COMM_WORLD_RANK'



# Local source tree for the job.
# NOTE(review): assumes $CONFIG_DIR expands to the directory holding this
# file, making local_dir its parent directory — confirm.
code:
  local_dir: '$CONFIG_DIR/..'

# Storage entries, keyed by name. The job command list runs
# `ls /mnt/lingjiejiang`, so this entry is presumably mounted at that
# path — confirm.
storage:
  lingjiejiang:
    storage_account_name: 'msranlpintern'
    container_name: 'lingjiejiang'
#   msranlp:
#     storage_account_name: msranlp
#     container_name: unilm
#   nlcredstone:
#     storage_account_name: nlcredstone
#     container_name: unilm
#   conversationhub:
#     storage_account_name: conversationhub
#     container_name: unilm
#   conversationhubhot:
#     storage_account_name: conversationhubhot
#     container_name: tengchaolv


# Grid search over the discrete `rank` parameter defined under `params`.
# `{rank}` inside the commands below names that parameter.
# NOTE(review): presumably replaced with each listed value per trial — confirm.
search:
  job_template:
    name: 'listwise_weight'
    sku: '1x40G8'
    # Alternative SKU (disabled):
    # sku: 2x40G8-A100-IB-NvLink
    identity: 'managed'
    mpi: true
    # One launcher process per node; the torchrun command below starts
    # 8 worker processes itself (--nproc_per_node=8).
    process_count_per_node: 1
    # Commands are single-quoted so embedded double quotes, braces and
    # "$" sequences are passed through exactly as written.
    command:
      - 'pip install -e ".[torch,metrics]"'
      - 'pip install deepspeed==0.14.4'
      - 'echo $${rank}'
      # - bash amlt_job/sas_mount.sh
      # - pip install -U flash-attn --no-build-isolation
      - 'pip install -U flash-attn==2.4.2 --no-build-isolation'
      - 'pip install tensorboard'
      - 'export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7'
      - 'ls /mnt/lingjiejiang'
      # - ls /mnt/msranlp
      # - ls /mnt/nlcredstone
      # - ls /mnt/conversationhub
      # - ls /mnt/conversationhubhot
      # - FORCE_TORCHRUN=1 llamafactory-cli train bash_script/{rank}
      # - torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$$OMPI_COMM_WORLD_RANK --master_addr="$$MASTER_ADDR" --master_port=$$MASTER_PORT src/train.py bash_script/{rank}
      # - sleep infinity
      - 'torchrun --nproc_per_node=8 src/train.py bash_script/{rank}'
    priority: 'High'
    submit_args:
      # Environment passed at submission time. The value is a single string
      # of MPI/NCCL options and must stay on one line unchanged.
      env:
        SINGULARITY_MPI_ENV: '-mca pml ucx --mca btl ^vader,tcp,openib -x NCCL_SOCKET_IFNAME=bond0 -x NCCL_IB_HCA=mlx5_0,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_9,mlx5_10,mlx5_11 -x NCCL_DEBUG=INFO'
      container_args:
        shm_size: '256g'
    # Each tag is one "Key:Value" string (no space after the colon), quoted
    # so it can never be read as a nested mapping.
    tags:
      - 'Project_Name:1.58-bit-LLMs'
      - 'ProjectID:PRJ-0349-A54'
      - 'Experiment:BitNet-scaling'

  type: 'grid'
  max_trials: 500
  params:
    - name: 'rank'
      spec: 'discrete'
      # Other value sets (disabled):
      # values: ['ta_chosen_llama3.1_instruct_dpo_2048_default_template_job.yaml', 'ta_chosen_tuluv2_dpo_2048_default_template_job.yaml', 'ta_rejected_llama3.1_instruct_dpo_2048_default_template_job.yaml', 'ta_rejected_tuluv2_dpo_2048_default_template_job.yaml']
      # values: ['ta_rejected_tuluv2_dpo_2048_default_template_job.yaml', 'ta_rejected_tuluv2_dpo_2048_default_template_v2_job.yaml', 'ta_rejected_tuluv2_dpo_2048_default_template_v3_job.yaml']
      # values: ['glanchatv2_full_sft_2048_default_template_job_lr5e6_e3.yaml', 'glanv2_full_sft_2048_default_template_job_lr5e6_e3.yaml', 'glanv2_glanchatv2_full_sft_2048_default_template_job_lr5e6_e3.yaml']
      values:
        - 'uf_llama3.1_instruct_dpo_2048_job_ta_rejected.yaml'

