include:
  - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml'

# Pipeline stages; GitLab runs them in the order listed here.
stages:
  # Used by the "build base" job.
  - build_base
  # Used by the "build mtt" job.
  - build
  # Used by the "test tox" and "test distributed" jobs.
  - test

# Base image job: builds ci/containers/Containerfile.base.
# The image tag is the first 8 characters of the Containerfile's SHA-256
# checksum, so the base image is re-built whenever Containerfile.base changes.
build base:
  extends: .container-builder-cscs-gh200
  stage: build_base
  variables:
    DOCKERFILE: ci/containers/Containerfile.base
  before_script:
    # Derive the image tag from the Containerfile contents.
    - TAG_DOCKERFILE=$(sha256sum $DOCKERFILE | head -c 8)
    - TAG=${TAG_DOCKERFILE}
    - export PERSIST_IMAGE_NAME=$CSCS_REGISTRY_PATH/base/mtt-base:$TAG
    # Write the image name to build.env, which is published below as a dotenv
    # report (BASE_IMAGE=...).
    - echo "BASE_IMAGE=$PERSIST_IMAGE_NAME" > build.env
    - 'echo "INFO: Building image $PERSIST_IMAGE_NAME"'
  artifacts:
    reports:
      dotenv: build.env

# Metatrain image job: builds ci/containers/Containerfile.mtt and names the
# result after the short commit SHA. BASE_IMAGE is listed as a build argument
# (presumably the value published by "build base" through build.env).
build mtt:
  extends: .container-builder-cscs-gh200
  stage: build
  variables:
    DOCKERFILE: ci/containers/Containerfile.mtt
    DOCKER_BUILD_ARGS: '["BASE_IMAGE"]'
    PERSIST_IMAGE_NAME: $CSCS_REGISTRY_PATH/base/mtt:$CI_COMMIT_SHORT_SHA
    GIT_STRATEGY: clone
  before_script:
    - 'echo "INFO: Building image $PERSIST_IMAGE_NAME"'

# Tox test job (single GPU): one node, one task, running in $BASE_IMAGE.
test tox:
  extends: .container-runner-daint-gh200
  stage: test
  image: $BASE_IMAGE
  timeout: 1h
  variables:
    GIT_STRATEGY: fetch
    SLURM_PARTITION: normal
    SLURM_JOB_NUM_NODES: '1'
    SLURM_NTASKS: '1'
    SLURM_TIMELIMIT: '01:00:00'
  script:
    # -r makes tox recreate its environments instead of reusing them.
    - tox -r

# Run distributed training (single node, multiple GPUs).
# Uses the image named after the short commit SHA (the same name "build mtt"
# sets as PERSIST_IMAGE_NAME), trains with options-distributed.yaml, then checks
# the training log to confirm that all 4 requested devices were used.
test distributed:
  stage: test
  extends: .container-runner-daint-gh200
  image: $CSCS_REGISTRY_PATH/base/mtt:$CI_COMMIT_SHORT_SHA
  timeout: 1h
  script:
    - . /mtt-venv/bin/activate
    - cd /metatrain/tests/distributed
    - mtt train options-distributed.yaml
    # Only rank 0 inspects the log. The variable is quoted and compared with
    # the POSIX "=" operator so that `[` cannot error out (and silently skip
    # the check) when SLURM_PROCID is unset or the shell is a strict POSIX sh;
    # an unset SLURM_PROCID is treated as rank 0 so the check still runs.
    # The folded scalar (>-) joins these lines into a single shell command.
    - >-
      if [ "${SLURM_PROCID:-0}" = "0" ]; then
      grep "Training on 4 devices" outputs/*/*/train.log
      || { echo "FAILED: Not training on 4 devices as requested."; exit 1; };
      fi
  variables:
    SLURM_JOB_NUM_NODES: 1
    SLURM_PARTITION: normal
    # 4 tasks with 1 GPU each; must match the "4 devices" expected above.
    SLURM_NTASKS: 4
    SLURM_GPUS_PER_TASK: 1
    SLURM_CPUS_PER_TASK: 16
    SLURM_TIMELIMIT: '00:10:00'
    USE_NCCL: 'cuda12'
