# Recipe for the unit-test jobs: the spec script below runs one pytest bucket
# under torchrun, and the products section lists the buckets to run.
type: basic
format_version: 1
maintainers: [mcore]
loggers: [stdout]
spec:
  name: '{test_case}_{environment}_{tag}'
  model: unit-tests
  nodes: 1
  build: mcore-pyt-{environment}
  gpus: 8
  platforms: dgx_h100
  # Bash script that runs a single test bucket with torchrun + pytest.
  # Single-brace names are template placeholders; shell parameter expansions
  # are written with doubled braces (presumably because the recipe tooling
  # treats single braces as placeholders - keep that convention when editing).
  script: |-
    ls

    export TAG={tag}
    export ENVIRONMENT={environment}
    export BUCKET="{test_case}"
    export UNIT_TEST_REPEAT={n_repeat}
    export UNIT_TEST_TIMEOUT=10

    set -euxo pipefail

    # The "latest" tag runs from the current checkout, any other tag from the
    # legacy checkout.
    if [[ "$TAG" == "latest" ]]; then
      TEST_PATH="/opt/megatron-lm"
    else
      TEST_PATH="/opt/megatron-lm-legacy/"
    fi

    cd "$TEST_PATH"

    # Extra pytest marker clauses. Each clause carries its own " and " prefix
    # so the suffix can be appended to a base expression even when it is empty.
    MARKER_SUFFIX=""
    if [[ "$TAG" == "legacy" ]]; then
      MARKER_SUFFIX+=" and not internal"
    fi

    if [[ "$ENVIRONMENT" == "lts" ]]; then
      MARKER_SUFFIX+=" and not flaky"
    fi

    if [[ "$ENVIRONMENT" == "dev" ]]; then
      MARKER_SUFFIX+=" and not flaky_in_dev"
    fi

    # Every test_case listed under products in the recipe file except this
    # job's own bucket, one path or glob per line.
    IGNORE_TEST_CASES=$(
      yq eval '
        with(.products[].test_case; del(.[] | select(. == env(BUCKET))))
        | .products[].test_case[]
      ' /opt/megatron-lm/tests/test_utils/recipes/unit-tests.yaml \
        | tr " " "\n"
    )

    IGNORE_ARGS=()
    while IFS= read -r test_case; do
      # A here-string always yields at least one line, which may be empty.
      [[ -n "$test_case" ]] || continue

      if [[ $test_case == *\** ]]; then
        # Deliberately unquoted: the shell expands the glob relative to
        # TEST_PATH so that every matching file is ignored individually.
        for file in $test_case; do
          # A glob without matches is left as the literal pattern; skip it.
          [[ -e "$file" ]] || continue
          IGNORE_ARGS+=("--ignore=$file")
        done
      else
        IGNORE_ARGS+=("--ignore=$test_case")
      fi
    done <<< "$IGNORE_TEST_CASES"

    echo "------ARGUMENTS for SLURM ---"
    MASTER_ADDR=${{MASTER_ADDR:-localhost}}
    MASTER_PORT=${{MASTER_PORT:-6000}}
    # Fall back to a single-node layout when the SLURM variables are absent,
    # otherwise "set -u" aborts the script on the unset variable.
    NUM_NODES=${{NUM_NODES:-${{SLURM_NNODES:-1}}}}
    GPUS_PER_NODE=${{GPUS_PER_NODE:-8}}
    NODE_RANK=${{NODE_RANK:-${{SLURM_NODEID:-0}}}}
    DISTRIBUTED_ARGS=(
        --nproc_per_node "$GPUS_PER_NODE"
        --nnodes "$NUM_NODES"
        --master_addr "$MASTER_ADDR"
        --master_port "$MASTER_PORT"
        --node_rank "$NODE_RANK"
        --log-dir {assets_dir}
        --tee "0:3"
        --redirects "3"
    )

    # Reduce memory usage by NCCL
    export NCCL_MAX_NCHANNELS=1
    export NCCL_NVLS_ENABLE=0

    # Notes on the two torchrun invocations below:
    # - They are executed directly from arrays instead of through echo + eval,
    #   so no argument needs embedded quote characters.
    # - The array-or-nothing expansion of IGNORE_ARGS keeps "set -u" from
    #   failing on an empty array in older bash releases.
    # - BUCKET stays unquoted on purpose so that a glob bucket expands to the
    #   matching files.
    for i in $(seq $UNIT_TEST_REPEAT); do
      echo "Running prod test suite."
      torchrun "${{DISTRIBUTED_ARGS[@]}}" -m pytest \
        -xvs \
        --cov-report=term \
        --cov-branch \
        --cov=megatron/core \
        --cov-report xml:coverage.xml \
        --no-cov-on-fail \
        ${{IGNORE_ARGS[@]+"${{IGNORE_ARGS[@]}}"}} \
        -m "not experimental$MARKER_SUFFIX" \
        $BUCKET

      if [[ "$TAG" == "latest" ]]; then
        # The ignore list is applied here as well, so the experimental pass
        # covers the same files as the pass above.
        torchrun "${{DISTRIBUTED_ARGS[@]}}" -m pytest \
          -xvs \
          --experimental \
          ${{IGNORE_ARGS[@]+"${{IGNORE_ARGS[@]}}"}} \
          -m "experimental$MARKER_SUFFIX" \
          $BUCKET
      fi
    done

    ls -al
    cp .coverage_0 {assets_dir}/coverage_report
    cp coverage.xml {assets_dir}

# Test buckets, one job family per test_case entry.
# The spec script exports the test_case value as BUCKET and builds pytest
# --ignore arguments from every other test_case listed here, which keeps the
# broad tests/unit_tests entry from re-running the narrower buckets wherever
# those arguments are passed.
# NOTE(review): the script reads this list from
# /opt/megatron-lm/tests/test_utils/recipes/unit-tests.yaml, presumably the
# installed copy of this file - confirm.
# NOTE(review): scope and time_limit are not referenced by the script;
# time_limit is presumably in seconds - confirm against the recipe tooling.
products:
  - test_case: [tests/unit_tests/data/]
    products:
      - environment: [lts, dev]
        tag: [latest, legacy]
        scope: [unit-tests]
        n_repeat: [1]
        time_limit: [1800]
  - test_case: [tests/unit_tests/dist_checkpointing/*.py]
    products:
      - environment: [lts, dev]
        tag: [latest, legacy]
        scope: [unit-tests]
        n_repeat: [1]
        time_limit: [1800]
  - test_case: [tests/unit_tests/dist_checkpointing/models/]
    products:
      - environment: [lts, dev]
        tag: [latest, legacy]
        scope: [unit-tests]
        n_repeat: [1]
        time_limit: [1800]
  - test_case: [tests/unit_tests/transformer/*.py]
    products:
      - environment: [lts, dev]
        tag: [latest, legacy]
        scope: [unit-tests]
        n_repeat: [1]
        time_limit: [1800]
  - test_case: [tests/unit_tests/transformer/moe]
    products:
      - environment: [lts, dev]
        tag: [latest, legacy]
        scope: [unit-tests]
        n_repeat: [1]
        time_limit: [1800]
  - test_case: [tests/unit_tests]
    products:
      - environment: [lts, dev]
        tag: [latest, legacy]
        scope: [unit-tests]
        n_repeat: [1]
        time_limit: [1800]
