# Use 2.1 for orbs
version: 2.1

# -------------------------------------------------------------------------------------
# Environments to run the jobs in
# -------------------------------------------------------------------------------------
# Executor settings shared by the GPU jobs; each job pulls them in with
# `<<: *gpu`.
gpu: &gpu
  machine:
    image: ubuntu-2004-cuda-11.2:202103-01
  resource_class: gpu.nvidia.medium.multi
  environment:
    # Quoted so the value stays the string "11.2" instead of becoming a float.
    CUDA_VERSION: "11.2"


# -------------------------------------------------------------------------------------
# Reusable steps (YAML anchors merged into the jobs below)
# -------------------------------------------------------------------------------------
# Cache key template: one cache per job name, invalidated whenever this config
# file or setup.py changes (both are folded in through their checksums).
cache_key: &cache_key 'cache-key-{{ .Environment.CIRCLE_JOB }}-{{ checksum ".circleci/config.yml" }}-{{ checksum "setup.py"}}'

# Step: upgrade setuptools, install the torch 1.10.1 / torchaudio 0.10.1
# "+cu111" wheels into the "fairseq" conda environment, then print the torch
# version that ended up installed.
install_dep_pt1_10: &install_dep_pt1_10
  - run:
      name: Install Pytorch Dependencies
      command: |
        source activate fairseq
        pip install -U setuptools
        # The +cu111 builds are resolved through the find-links index page.
        pip install --find-links https://download.pytorch.org/whl/torch_stable.html \
          torch==1.10.1+cu111 torchaudio==0.10.1+cu111
        python -c 'import torch; print("Torch version:", torch.__version__)'

# Step: upgrade setuptools, install the torch 1.12.1 / torchaudio 0.12.1
# "+cu116" wheels into the "fairseq" conda environment, then print the torch
# version that ended up installed.
install_dep_pt1_12: &install_dep_pt1_12
  - run:
      name: Install Pytorch Dependencies
      command: |
        source activate fairseq
        pip install -U setuptools
        # The +cu116 builds are resolved through the find-links index page.
        pip install --find-links https://download.pytorch.org/whl/torch_stable.html \
          torch==1.12.1+cu116 torchaudio==0.12.1+cu116
        python -c 'import torch; print("Torch version:", torch.__version__)'

# Step: install fairscale, then this checkout in editable mode with its "dev"
# and "docs" extras, inside the "fairseq" conda environment. The torch version
# is printed afterwards so the log shows what the install left in place.
install_repo: &install_repo
  - run:
      name: Install Repository
      command: |
        source activate fairseq
        python -m pip install fairscale
        python -m pip install --editable ".[dev,docs]"
        python -c 'import torch; print("Torch version:", torch.__version__)'

# Step: run the GPU binary tests (tests/gpu/test_binaries_gpu.py) with pytest
# inside the "fairseq" conda environment. Only this one test file is run here.
run_unittests: &run_unittests
  - run:
      name: Run Unit Tests
      command: |
        source activate fairseq
        pytest tests/gpu/test_binaries_gpu.py

# Step: diagnostic output only. Lists the pyenv-managed Python versions and
# prints the nvidia-smi report; it runs from the home directory rather than
# the job's working directory.
check_nvidia_driver: &check_nvidia_driver
  - run:
      name: Check NVIDIA Driver
      command: |
        pyenv versions
        nvidia-smi
      working_directory: '~/'

# Step: make sure Miniconda is installed under $HOME/miniconda, put it on PATH
# for this and later steps through $BASH_ENV, and make sure a Python 3.8 conda
# environment named "fairseq" exists. Finishes by upgrading pip in that
# environment.
create_conda_env: &create_conda_env
  - run:
      name: Install and Create Conda Environment
      command: |
        # Install only when no Miniconda prefix is on disk yet: in batch mode
        # the installer aborts if the target directory already exists.
        if [ ! -d "$HOME/miniconda" ]
        then
          # --fail makes an HTTP error fail the step instead of saving the
          # error page as the installer; --location follows redirects.
          curl --fail --location --output ~/miniconda.sh \
            https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
          bash ~/miniconda.sh -b -p "$HOME/miniconda"
          rm ~/miniconda.sh
        fi
        # Persist the PATH change for later steps, and load it now so conda is
        # available for the rest of this script.
        echo 'export PATH=$HOME/miniconda/bin:$PATH' >> "$BASH_ENV"
        source "$BASH_ENV"
        if [ ! -d "$HOME/miniconda/envs/fairseq" ]
        then
          conda create -y -n fairseq python=3.8
        fi
        source activate fairseq
        python --version
        pip install --upgrade pip
# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------

# Two GPU test jobs that differ only in which PyTorch install step they use.
# Each job runs, in order: checkout, driver check, conda setup, cache restore,
# PyTorch install, cache save (of ~/miniconda/), repository install, tests.
jobs:

  # GPU tests against the PyTorch 1.10 install step.
  gpu_tests_pt1_10:
    <<: *gpu

    working_directory: ~/fairseq-py

    steps:
      - checkout
      - <<: *check_nvidia_driver
      - <<: *create_conda_env
      # NOTE(review): the cache holds ~/miniconda/ but is restored only after
      # the conda step has already installed into that directory, so a cache
      # hit appears to overwrite the fresh install rather than skip it.
      # Confirm whether this ordering is intended.
      - restore_cache:
          key: *cache_key
      - <<: *install_dep_pt1_10
      # Saved before the repository install, so the cache contains the conda
      # environment with PyTorch but not the editable fairseq install.
      - save_cache:
          paths:
            - ~/miniconda/
          key: *cache_key
      - <<: *install_repo
      - <<: *run_unittests

  # GPU tests against the PyTorch 1.12 install step.
  gpu_tests_pt1_12:
    <<: *gpu

    working_directory: ~/fairseq-py

    steps:
      - checkout
      - <<: *check_nvidia_driver
      - <<: *create_conda_env
      # NOTE(review): same ordering as the 1.10 job; the cache is restored
      # after the conda step has already run. Confirm this is intended.
      - restore_cache:
          key: *cache_key
      - <<: *install_dep_pt1_12
      # Saved before the repository install, so the cache contains the conda
      # environment with PyTorch but not the editable fairseq install.
      - save_cache:
          paths:
            - ~/miniconda/
          key: *cache_key
      - <<: *install_repo
      - <<: *run_unittests

# A single "build" workflow that runs both GPU test jobs. Neither job declares
# a dependency on the other.
workflows:
  version: 2
  build:
    jobs:
      - gpu_tests_pt1_10
      - gpu_tests_pt1_12
