# Installs the package alongside JAX/Flax and runs
# tests/others/test_dependencies.py (see the job below).
name: Run Flax dependency tests

on:
  # Every push to main, regardless of which files changed.
  push:
    branches:
      - main
  # Pull requests targeting main, but only when they touch a .py file
  # under src/diffusers/.
  pull_request:
    branches:
      - main
    paths:
      - "src/diffusers/**.py"

# Runs sharing a group cancel each other. github.head_ref is only set for
# pull_request events, so pushes fall back to their unique run_id and are
# never cancelled by a later push.
concurrency:
  cancel-in-progress: true
  group: "${{ github.workflow }}-${{ github.head_ref || github.run_id }}"

jobs:
  # Installs the package in editable mode together with jax, flax and jaxlib,
  # then runs the dependency test file against that environment.
  check_flax_dependencies:
    runs-on: ubuntu-22.04
    # The job only checks out and reads the repository, so limit the
    # GITHUB_TOKEN to read access instead of inheriting the repo default.
    permissions:
      contents: read
    steps:
      # checkout@v4 / setup-python@v5 run on Node 20; the previous majors
      # (v3 / v4) relied on the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Quoted so the version stays a string rather than a YAML float.
          python-version: "3.8"
      - name: Install dependencies
        # Everything is installed into a virtualenv at /opt/venv; uv is used
        # as the installer front-end for the actual packages.
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m pip install --upgrade pip uv
          python -m uv pip install -e .
          python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
          python -m uv pip install "flax>=0.4.1"
          python -m uv pip install "jaxlib>=0.1.65"
          python -m uv pip install pytest
      - name: Check for soft dependencies
        # Each `run` step starts a fresh shell, so the PATH export from the
        # install step does not carry over and is repeated here.
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          pytest tests/others/test_dependencies.py
