name: CI

on: [push, pull_request]

jobs:
  # Runs tests/reference_algorithm_tests.py for the fastmri workload with the
  # target-setting momentum submissions (PyTorch first, then JAX).
  fastmri:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=fastmri \
          --framework=pytorch \
          --global_batch_size=8 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
        python tests/reference_algorithm_tests.py \
          --workload=fastmri \
          --framework=jax \
          --global_batch_size=8 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py for the wmt workload on JAX with
  # the target-setting NAdamW submission.
  wmt_jax:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=wmt \
          --framework=jax \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py for the wmt workload on PyTorch
  # with the target-setting NAdamW submission.
  wmt_pytorch:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=wmt \
          --framework=pytorch \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py on JAX for imagenet_vit (AdamW
  # submission) and then imagenet_resnet (momentum submission).
  imagenet_jax:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=imagenet_vit \
          --framework=jax \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
        python tests/reference_algorithm_tests.py \
          --workload=imagenet_resnet \
          --framework=jax \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_momentum.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py on PyTorch for imagenet_resnet
  # (momentum submission) and then imagenet_vit (AdamW submission).
  imagenet_pytorch:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=imagenet_resnet \
          --framework=pytorch \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_momentum.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
        python tests/reference_algorithm_tests.py \
          --workload=imagenet_vit \
          --framework=pytorch \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
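  # NOTE(review): the criteo jobs below are active. They were previously meant
  # to stay disabled until
  # https://github.com/mlcommons/algorithmic-efficiency/issues/339 was
  # resolved; confirm that issue's status and drop this note if it is closed.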
  # Runs tests/reference_algorithm_tests.py for the criteo1tb_test workload on
  # JAX with the target-setting AdamW submission.
  criteo_jax:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=criteo1tb_test \
          --framework=jax \
          --global_batch_size=1 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py for the criteo1tb_test workload on
  # PyTorch with the target-setting AdamW submission.
  criteo_pytorch:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=criteo1tb_test \
          --framework=pytorch \
          --global_batch_size=1 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py on JAX for librispeech_conformer
  # and then librispeech_deepspeech, both with the AdamW submission.
  speech_jax:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=librispeech_conformer \
          --framework=jax \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
        python tests/reference_algorithm_tests.py \
          --workload=librispeech_deepspeech \
          --framework=jax \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py on PyTorch for
  # librispeech_deepspeech and then librispeech_conformer, both with the AdamW
  # submission.
  speech_pytorch:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=librispeech_deepspeech \
          --framework=pytorch \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
        python tests/reference_algorithm_tests.py \
          --workload=librispeech_conformer \
          --framework=pytorch \
          --global_batch_size=2 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
  # Runs tests/reference_algorithm_tests.py for the ogbg workload with the
  # target-setting Nesterov submissions (PyTorch first, then JAX).
  ogbg:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install Modules and Run
      run: |
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
        pip install ".[full]"
        pip install -e .
        python tests/reference_algorithm_tests.py \
          --workload=ogbg \
          --framework=pytorch \
          --global_batch_size=8 \
          --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nesterov.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
        python tests/reference_algorithm_tests.py \
          --workload=ogbg \
          --framework=jax \
          --global_batch_size=8 \
          --submission_path=reference_algorithms/target_setting_algorithms/jax_nesterov.py \
          --tuning_search_space=reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
  # Installs the package with its full, jax_cpu and pytorch_cpu extras, then
  # runs the listed pytest files one at a time.
  pytest:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.9
      uses: actions/setup-python@v5
      with:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: "3.9"
    - name: Install pytest
      run: |
        python -m pip install --upgrade pip
        pip install pytest
        # Extras are quoted so the shell never treats [...] as a glob pattern.
        pip install ".[full]"
        pip install ".[jax_cpu]"
        pip install ".[pytorch_cpu]"
    - name: Run pytest
      run: |
        # -x makes each invocation stop at its first failing test.
        pytest -vx tests/version_test.py
        pytest -vx tests/test_num_params.py
        pytest -vx tests/test_param_shapes.py
        pytest -vx tests/test_param_types.py
        pytest -vx tests/test_ssim.py